Shortcuts

Source code for mmpretrain.models.utils.layer_scale

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """LayerScale layer.

    Multiplies the input by a learnable per-channel weight vector of
    length ``dim``.

    Args:
        dim (int): Dimension of input features.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 1e-5.
        inplace (bool): Whether to do the multiplication in-place on the
            input tensor. Defaults to False.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first', representing (B, N, C) and (B, C, H, W)
            format data respectively. Defaults to 'channels_last'.
    """

    def __init__(self,
                 dim: int,
                 layer_scale_init_value: Union[float, torch.Tensor] = 1e-5,
                 inplace: bool = False,
                 data_format: str = 'channels_last'):
        super().__init__()
        # Kept as an ``assert`` (not a raised ValueError) so that callers
        # relying on ``AssertionError`` keep working.
        assert data_format in ('channels_last', 'channels_first'), \
            "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        # One learnable scale per channel, every entry starting at the
        # init value (a tensor init value is broadcast against ones(dim)).
        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` channel-wise by the learned weight.

        Args:
            x (torch.Tensor): Input whose channel axis has size ``dim``:
                the last axis for 'channels_last', or the third-from-last
                axis for 'channels_first'.

        Returns:
            torch.Tensor: The scaled tensor. When ``inplace`` is True this
            is ``x`` itself, modified in-place.
        """
        if self.data_format == 'channels_first':
            # Reshape (dim,) -> (dim, 1, 1) so the weight broadcasts over
            # the two trailing (spatial) axes.
            scale = self.weight.view(-1, 1, 1)
        else:
            # A 1-D weight already broadcasts along the last axis.
            scale = self.weight

        if self.inplace:
            return x.mul_(scale)
        return x * scale