• Docs >
  • Module code >
  • mmpretrain.engine.optimizers.layer_decay_optim_wrapper_constructor
Shortcuts

Source code for mmpretrain.engine.optimizers.layer_decay_optim_wrapper_constructor

# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional

from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch import nn
from torch.nn import GroupNorm, LayerNorm

from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Different learning rates are set for different layers of backbone.

    By default, each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain the following fields:

    - ``layer_decay_rate`` (float): The learning rate of a parameter is
      multiplied by ``layer_decay_rate ** (max_id - layer_id - 1)``, where
      ``(layer_id, max_id)`` is returned by ``get_layer_depth`` for that
      parameter. Usually, it's less than 1, so that the earlier layers will
      have a lower learning rate. Defaults to 1.
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in normalization
      layers).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters
    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be
      ignored. It should be a dict and may contain fields ``decay_mult``.
      (The ``lr_mult`` is disabled in this constructor).

    Example:

    In the config file, you can use this constructor as below:

    .. code:: python

        optim_wrapper = dict(
            optimizer=dict(
                type='AdamW',
                lr=4e-3,
                weight_decay=0.05,
                eps=1e-8,
                betas=(0.9, 0.999)),
            constructor='LearningRateDecayOptimWrapperConstructor',
            paramwise_cfg=dict(
                layer_decay_rate=0.75,  # layer-wise lr decay factor
                norm_decay_mult=0.,
                flat_decay_mult=0.,
                custom_keys={
                    '.cls_token': dict(decay_mult=0.0),
                    '.pos_embed': dict(decay_mult=0.0)
                }))
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   get_layer_depth: Optional[Callable] = None,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg. The method
        recurses into every child module, and the outermost call (the one
        with an empty ``prefix``) finally emits a per-layer debug summary.

        Args:
            params (List[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module. Defaults to ``''``.
            get_layer_depth (Callable, optional): A callable that receives
                the full parameter name and returns ``(layer_id, max_id)``.
                If None, ``module.get_layer_depth`` is used.
                Defaults to None.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            NotImplementedError: If ``get_layer_depth`` is None and
                ``module`` has no ``get_layer_depth`` attribute.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of
        # str, so that the longest matching key wins
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        # The model should have `get_layer_depth` method unless one is given
        if get_layer_depth is None:
            if not hasattr(module, 'get_layer_depth'):
                raise NotImplementedError('The layer-wise learning rate decay need'
                                          f' the model {type(module)} has'
                                          ' `get_layer_depth` method.')
            get_layer_depth = module.get_layer_depth

        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)

        # special rules for norm layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue

            param_name = prefix + name
            param_group = {'params': [param]}

            # ``weight_decay`` is only written into the group when the
            # optimizer config provides a base value to scale.
            if self.base_wd is not None:
                base_wd = self.base_wd
                custom_key = next(
                    (key for key in sorted_keys if key in param_name), None)
                # custom parameters decay
                if custom_key is not None:
                    custom_cfg = custom_keys[custom_key].copy()
                    decay_mult = custom_cfg.pop('decay_mult', 1.)
                    param_group['weight_decay'] = base_wd * decay_mult
                    # add custom settings to param_group
                    param_group.update(custom_cfg)
                # norm decay
                elif is_norm and norm_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
                # bias decay
                elif name == 'bias' and bias_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
                # flatten parameters decay
                elif param.ndim == 1 and flat_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * flat_decay_mult
                else:
                    param_group['weight_decay'] = base_wd

            layer_id, max_id = get_layer_depth(param_name)
            scale = decay_rate**(max_id - layer_id - 1)
            param_group['lr'] = self.base_lr * scale
            param_group['lr_scale'] = scale
            param_group['layer_id'] = layer_id
            param_group['param_name'] = param_name

            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}{child_name}.'
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                get_layer_depth=get_layer_depth,
            )

        # Only the outermost call summarizes the collected groups.
        if prefix == '':
            logger = MMLogger.get_current_instance()
            groups_by_layer = defaultdict(list)
            for group in params:
                groups_by_layer[group['layer_id']].append(group)

            for layer_id, layer_groups in groups_by_layer.items():
                lr_scale = layer_groups[0]['lr_scale']
                lr = layer_groups[0]['lr']
                msg = [
                    f'layer {layer_id} params '
                    f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
                ]
                for group in layer_groups:
                    # A group has no ``weight_decay`` entry when the
                    # optimizer config did not set one; indexing it
                    # unconditionally would raise ``KeyError`` here.
                    if 'weight_decay' in group:
                        wd_text = f'{group["weight_decay"]:.3g}'
                    else:
                        wd_text = 'default'
                    msg.append(f'\t{group["param_name"]}: '
                               f'weight_decay={wd_text}')
                logger.debug('\n'.join(msg))
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.