Note

You are reading the documentation for MMSelfSup 0.x, which will be deprecated at the end of 2022. We recommend upgrading to an MMSelfSup 1.0.0rc version to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.core.optimizer.constructor

# Copyright (c) OpenMMLab. All rights reserved.
import re

import torch.distributed as dist
from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg, print_log


@OPTIMIZER_BUILDERS.register_module(force=True)
class DefaultOptimizerConstructor:
    """Rewrote default constructor for optimizers.

    By default each parameter shares the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict that maps a regular expression (matched against parameter
    names) to an options dict.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are

                - `type`: class name of the optimizer.

            Optional fields are

                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
            Defaults to None.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = {'bias': dict(weight_decay=0., lars_exclude=True)}
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg

    def __call__(self, model):
        # unwrap (Distributed)DataParallel wrappers before reading parameters
        if hasattr(model, 'module'):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        paramwise_options = self.paramwise_cfg
        # if no paramwise option is specified, just use the global setting
        if not paramwise_options:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            assert isinstance(paramwise_options, dict)
            params = []
            for name, param in model.named_parameters():
                # each parameter gets its own param group
                param_group = {'params': [param]}
                if not param.requires_grad:
                    params.append(param_group)
                    continue

                for regexp, options in paramwise_options.items():
                    if re.search(regexp, name):
                        for key, value in options.items():
                            if key.endswith('_mult'):  # is a multiplier
                                key = key[:-5]
                                assert key in optimizer_cfg, \
                                    f'{key} not in optimizer_cfg'
                                value = optimizer_cfg[key] * value
                            param_group[key] = value
                            if not dist.is_initialized() or \
                                    dist.get_rank() == 0:
                                print_log(f'paramwise_options -- '
                                          f'{name}: {key}={value}')

                # otherwise use the global settings
                params.append(param_group)

            optimizer_cfg['params'] = params
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
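
To make the regex matching and the ``_mult`` convention concrete, here is a minimal usage sketch. The toy model, the regular-expression keys and the printed summary are illustrative assumptions and not part of the module above; only DefaultOptimizerConstructor itself comes from this file.

    import torch.nn as nn

    from mmselfsup.core.optimizer.constructor import DefaultOptimizerConstructor

    # Hypothetical toy model, used only to produce named parameters
    # such as 'conv.weight', 'bn.bias', 'fc.bias'.
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.fc = nn.Linear(8, 2)

    optimizer_cfg = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=1e-4)
    # Keys are regular expressions matched against parameter names with
    # re.search. A key ending in '_mult' rescales the corresponding global
    # value, e.g. lr_mult=0.1 below gives lr = 0.1 * 0.1 = 0.01 to every
    # parameter whose name contains 'bias'.
    paramwise_cfg = {
        'bn': dict(weight_decay=0.),
        'bias': dict(lr_mult=0.1),
    }

    optim_builder = DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
    optimizer = optim_builder(ToyModel())
    for group in optimizer.param_groups:
        print(group['lr'], group['weight_decay'])

When several regular expressions match the same parameter, they are applied in the dict's insertion order, so a later entry overrides an earlier one for any option key they both set.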