Shortcuts

Source code for mmpretrain.models.backbones.deit3

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import numpy as np
import torch
from mmcv.cnn import Linear, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.utils import deprecated_api_warning
from torch import nn

from mmpretrain.registry import MODELS
from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
                     resize_pos_embed, to_2tuple)
from .vision_transformer import VisionTransformer


class DeiT3FFN(BaseModule):
    """Feed-forward network used inside DeiT3 encoder layers.

    Compared with a plain FFN, the output of the fully-connected stack is
    passed through ``LayerScale`` (or an identity when layer scale is
    disabled) before the shortcut dropout and the optional residual add.

    Args:
        embed_dims (int): Input and output feature dimension.
            Defaults to 256.
        feedforward_channels (int): Width of the hidden fully-connected
            layers. Defaults to 1024.
        num_fcs (int): Total number of fully-connected layers; must be at
            least 2. Defaults to 2.
        act_cfg (dict): Config of the activation placed after every hidden
            fully-connected layer.
            Defaults to ``dict(type='ReLU', inplace=True)``.
        ffn_drop (float): Dropout probability applied after every
            fully-connected layer. Defaults to 0.0.
        dropout_layer (dict, optional): Config of the dropout applied to the
            branch output before the shortcut is added. A falsy value
            disables it. Defaults to None.
        add_identity (bool): Whether to add the identity (shortcut) tensor
            to the output. Defaults to True.
        use_layer_scale (bool): Whether to scale the branch output with
            ``LayerScale``. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        **kwargs: Accepted for signature compatibility; not used here.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 use_layer_scale=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'

        # Plain bookkeeping attributes.
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.add_identity = add_identity

        # One activation instance is shared by every hidden block.
        self.activate = build_activation_layer(act_cfg)

        # Hidden blocks: only the first one consumes ``embed_dims`` features,
        # all later ones map hidden width to hidden width.
        hidden_blocks = []
        for block_idx in range(num_fcs - 1):
            fan_in = embed_dims if block_idx == 0 else feedforward_channels
            hidden_blocks.append(
                Sequential(
                    Linear(fan_in, feedforward_channels),
                    self.activate,
                    nn.Dropout(ffn_drop),
                ))

        # Projection back to ``embed_dims`` followed by a final dropout.
        self.layers = Sequential(
            *hidden_blocks,
            Linear(feedforward_channels, embed_dims),
            nn.Dropout(ffn_drop),
        )

        # Dropout on the branch output (e.g. DropPath), or a no-op.
        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()

        # Per-channel scaling of the branch output, or a no-op.
        self.gamma2 = LayerScale(embed_dims) if use_layer_scale \
            else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Run the FFN branch and optionally add the shortcut.

        Args:
            x (torch.Tensor): Input of the fully-connected stack.
            identity (torch.Tensor, optional): Tensor used as the shortcut.
                When None, ``x`` itself is used. Ignored if
                ``add_identity`` is False.

        Returns:
            torch.Tensor: The branch output, plus the shortcut when
            ``add_identity`` is True.
        """
        branch = self.dropout_layer(self.gamma2(self.layers(x)))
        if not self.add_identity:
            return branch
        shortcut = x if identity is None else identity
        return shortcut + branch


class DeiT3TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in DeiT3.

    The differences between DeiT3TransformerEncoderLayer &
    TransformerEncoderLayer:
        1. Use LayerScale.

    The layer is a pre-norm block: ``ln1`` normalizes the input of the
    attention branch and ``ln2`` normalizes the input of the FFN branch.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        use_layer_scale (bool): Whether to use layer_scale in
            DeiT3TransformerEncoderLayer. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 use_layer_scale=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        # Norm applied to the input of the attention branch.
        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            use_layer_scale=use_layer_scale)

        # Norm applied to the input of the FFN branch.
        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        self.ffn = DeiT3FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            use_layer_scale=use_layer_scale)

    def init_weights(self):
        """Initialize weights, then re-initialize the FFN linear layers.

        After the base-class initialization, every ``nn.Linear`` inside the
        FFN gets Xavier-uniform weights and a near-zero normal bias
        (std=1e-6).
        """
        super(DeiT3TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        """Apply the attention branch and the FFN branch with residuals.

        Args:
            x (torch.Tensor): Input token features.

        Returns:
            torch.Tensor: Output token features.
        """
        # Attention branch: the residual is added here.
        x = x + self.attn(self.ln1(x))
        # FFN branch: must be normalized by its own norm layer ``ln2``
        # (not ``ln1``); the FFN adds the un-normalized ``x`` as shortcut.
        x = self.ffn(self.ln2(x), identity=x)
        return x


@MODELS.register_module()
class DeiT3(VisionTransformer):
    """DeiT3 backbone.

    A PyTorch implementation of `DeiT III: Revenge of the ViT
    <https://arxiv.org/pdf/2204.07118.pdf>`_

    The differences between DeiT3 & VisionTransformer:

    1. Use LayerScale.
    2. Concat cls token after adding pos_embed.

    Args:
        arch (str | dict): DeiT3 architecture. If use string, choose from
            'small', 'medium', 'base', 'large' and 'huge' (or their
            one-letter abbreviations). If use dict, it should have below
            keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last layer. The given sequence is copied
            and never modified. Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        use_layer_scale (bool): Whether to use layer_scale in DeiT3.
            Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['m', 'medium'], {
                'embed_dims': 512,
                'num_layers': 12,
                'num_heads': 8,
                'feedforward_channels': 2048,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['h', 'huge'], {
                'embed_dims': 1280,
                'num_layers': 32,
                'num_heads': 16,
                'feedforward_channels': 5120
            }),
    }
    num_extra_tokens = 1  # class token

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 use_layer_scale=True,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        # Deliberately start the MRO lookup *after* VisionTransformer so its
        # ``__init__`` is skipped; this class builds every sub-module itself.
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != 'cls_token':
            self.cls_token = None
            # Shadow the class-level default on this instance only.
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        # Only patch tokens get a position embedding; the cls token is
        # concatenated after the embedding has been added (see ``forward``).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalize on a private copy: item assignment on the caller's object
        # would fail for tuples and would silently mutate a caller's list.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                use_layer_scale=use_layer_scale)
            # Per-layer overrides win over the shared defaults above.
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg))

        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

    def forward(self, x):
        """Extract features from the requested encoder layers.

        Args:
            x (torch.Tensor): Input images passed to the patch embedding.

        Returns:
            tuple: One formatted output (produced by ``_format_output``) per
            layer index listed in ``out_indices``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # The position embedding covers patch tokens only, hence
        # ``num_extra_tokens=0``.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=0)
        x = self.drop_after_pos(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            # The final norm is applied only after the last layer.
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Resize a checkpoint's ``pos_embed`` to this model's shape.

        Registered as a load-state-dict pre-hook. Does nothing when the
        checkpoint has no ``pos_embed`` entry or its shape already matches;
        otherwise the entry in ``state_dict`` is replaced in place.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # The checkpoint grid is taken to be square: side = sqrt(tokens).
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1])))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(
                state_dict[name],
                ckpt_pos_embed_shape,
                pos_embed_shape,
                self.interpolate_mode,
                num_extra_tokens=0,  # The cls token adding is after pos_embed
            )
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.