Source code for mmpretrain.models.backbones.efficientformer

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional, Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
                             build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential

from mmpretrain.registry import MODELS
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling


class AttentionWithBias(BaseModule):
    """Multi-head Attention Module with attention_bias.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        key_dim (int): The dimension of q, k. Defaults to 32.
        attn_ratio (float): The dimension of v equals
            ``key_dim * attn_ratio``. Defaults to 4.
        resolution (int): The height and width of attention_bias.
            Defaults to 7.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 key_dim=32,
                 attn_ratio=4.,
                 resolution=7,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.attn_ratio = attn_ratio
        self.key_dim = key_dim
        self.nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        h = self.dh + self.nh_kd * 2
        self.qkv = nn.Linear(embed_dims, h)
        self.proj = nn.Linear(self.dh, embed_dims)

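        # Build a relative position bias table: each pair of points on the
        # ``resolution x resolution`` grid is mapped to its absolute offset
        # ``(|dy|, |dx|)``. Pairs sharing an offset share one learnable bias
        # per head, and ``attention_bias_idxs`` gathers the table into the
        # full (N, N) bias that is added to the attention logits.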
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """change the mode of model."""
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        """forward function.

        Args:
            x (tensor): input features with shape of (B, N, C)
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
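
# Illustrative usage (not part of the original source): with the default
# ``resolution=7``, the sequence length N must equal 7 * 7 = 49.
#
#     >>> attn = AttentionWithBias(embed_dims=448)
#     >>> attn(torch.rand(1, 49, 448)).shape
#     torch.Size([1, 49, 448])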


class Flat(nn.Module):
    """Flat the input from (B, C, H, W) to (B, H*W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        x = x.flatten(2).transpose(1, 2)
        return x
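
# Illustrative usage (not part of the original source): ``Flat`` bridges the
# 4D ``Meta4D`` stages and the 3D ``Meta3D`` blocks in the last stage.
#
#     >>> Flat()(torch.rand(1, 448, 7, 7)).shape
#     torch.Size([1, 49, 448])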


class LinearMlp(BaseModule):
    """MLP implemented with linear layers.

    The shape of input and output tensor are (B, N, C).

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int, optional): Dimension of hidden features.
            Defaults to None, which means the same as ``in_features``.
        out_features (int, optional): Dimension of output features.
            Defaults to None, which means the same as ``in_features``.
        act_cfg (dict): The config dict for activation between linear
            layers. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (dict, optional): The config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, N, C).

        Returns:
            torch.Tensor: output tensor with shape (B, N, C).
        """
        x = self.drop1(self.act(self.fc1(x)))
        x = self.drop2(self.fc2(x))
        return x
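
# Illustrative usage (not part of the original source): token-wise MLP on
# (B, N, C) features, as used inside ``Meta3D``.
#
#     >>> mlp = LinearMlp(in_features=448, hidden_features=1792)
#     >>> mlp(torch.rand(1, 49, 448)).shape
#     torch.Size([1, 49, 448])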


class ConvMlp(BaseModule):
    """MLP implemented with 1x1 convolutions.

    The shape of input and output tensor are (B, C, H, W).

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int, optional): Dimension of hidden features.
            Defaults to None, which means the same as ``in_features``.
        out_features (int, optional): Dimension of output features.
            Defaults to None, which means the same as ``in_features``.
        norm_cfg (dict): Config dict for the normalization layers after
            each convolution. Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation between pointwise
            convolutions. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (dict, optional): The config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
        self.norm2 = build_norm_layer(norm_cfg, out_features)[1]

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, C, H, W).

        Returns:
            torch.Tensor: output tensor with shape (B, C, H, W).
        """

        x = self.act(self.norm1(self.fc1(x)))
        x = self.drop(x)
        x = self.norm2(self.fc2(x))
        x = self.drop(x)
        return x
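
# Illustrative usage (not part of the original source): channel MLP on
# (B, C, H, W) features, as used inside ``Meta4D``. Note the extra BN
# layers compared with ``LinearMlp``.
#
#     >>> mlp = ConvMlp(in_features=224, hidden_features=896)
#     >>> mlp(torch.rand(1, 224, 14, 14)).shape
#     torch.Size([1, 224, 14, 14])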


class Meta3D(BaseModule):
    """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
    (B, N, C)."""

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.token_mixer = AttentionWithBias(dim)
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim)
            self.ls2 = LayerScale(dim)
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        return x
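
# Illustrative usage (not part of the original source): the attention-based
# block keeps the (B, N, C) shape; N must be 49 for the default attention
# bias resolution of 7.
#
#     >>> block = Meta3D(dim=448)
#     >>> block(torch.rand(1, 49, 448)).shape
#     torch.Size([1, 49, 448])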


class Meta4D(BaseModule):
    """Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape
    (B, C, H, W)."""

    def __init__(self,
                 dim,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.token_mixer = Pooling(pool_size=pool_size)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim, data_format='channels_first')
            self.ls2 = LayerScale(dim, data_format='channels_first')
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
        x = x + self.drop_path(self.ls2(self.mlp(x)))
        return x
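
# Illustrative usage (not part of the original source): the pooling-based
# block keeps the (B, C, H, W) shape, so it works at any spatial size.
#
#     >>> block = Meta4D(dim=48)
#     >>> block(torch.rand(1, 48, 56, 56)).shape
#     torch.Size([1, 48, 56, 56])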


def basic_blocks(in_channels,
                 out_channels,
                 index,
                 layers,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop_rate=.0,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 vit_num=1,
                 has_downsamper=False):
    """generate EfficientFormer blocks for a stage."""
    blocks = []
    if has_downsamper:
        blocks.append(
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                norm_cfg=dict(type='BN'),
                act_cfg=None))
    if index == 3 and vit_num == layers[index]:
        blocks.append(Flat())
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
            sum(layers) - 1)
        if index == 3 and layers[index] - block_idx <= vit_num:
            blocks.append(
                Meta3D(
                    out_channels,
                    mlp_ratio=mlp_ratio,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                ))
        else:
            blocks.append(
                Meta4D(
                    out_channels,
                    pool_size=pool_size,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale))
            if index == 3 and layers[index] - block_idx - 1 == vit_num:
                blocks.append(Flat())
    blocks = nn.Sequential(*blocks)
    return blocks
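
# Illustrative usage (not part of the original source): the last stage of
# the 'l1' arch (layers=[3, 2, 6, 4], vit_num=1) builds a downsampling
# conv, three Meta4D blocks, a Flat layer, and one Meta3D block, so a
# 14x14 feature map comes out as 49 flattened tokens.
#
#     >>> stage = basic_blocks(224, 448, 3, [3, 2, 6, 4],
#     ...                      vit_num=1, has_downsamper=True)
#     >>> stage(torch.rand(1, 224, 14, 14)).shape
#     torch.Size([1, 49, 448])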


@MODELS.register_module()
class EfficientFormer(BaseBackbone):
    """EfficientFormer.

    A PyTorch implementation of EfficientFormer introduced by:
    `EfficientFormer: Vision Transformers at MobileNet Speed
    <https://arxiv.org/abs/2206.01191>`_

    Modified from the `official repo
    <https://github.com/snap-research/EfficientFormer>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should
            be one of the architectures in
            ``EfficientFormer.arch_settings``. And if dict, it should
            include the following 4 keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.
            - downsamples (list[int]): Has downsample or not in the four
              stages.
            - vit_num (int): The num of vit blocks in the last stage.

            Defaults to 'l1'.
        in_channels (int): The num of input channels. Defaults to 3.
        pool_size (int): The pooling size of ``Meta4D`` blocks.
            Defaults to 3.
        mlp_ratios (int): The dimension ratio of the MLPs in ``Meta4D``
            blocks. Defaults to 4.
        reshape_last_feat (bool): Whether to reshape the feature map from
            (B, N, C) to (B, C, H, W) in the last stage, when the
            ``vit_num`` in ``arch`` is not 0. Defaults to False. Usually
            set to True in downstream tasks.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to -1.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        drop_rate (float): Dropout rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_layer_scale (bool): Whether to use layer scale in MetaFormer
            blocks. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Example:
        >>> from mmpretrain.models import EfficientFormer
        >>> import torch
        >>> inputs = torch.rand((1, 3, 224, 224))
        >>> # build EfficientFormer backbone for classification task
        >>> model = EfficientFormer(arch="l1")
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 448, 49)
        >>> # build EfficientFormer backbone for downstream task
        >>> model = EfficientFormer(
        ...     arch="l3",
        ...     out_indices=(0, 1, 2, 3),
        ...     reshape_last_feat=True)
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 56, 56)
        (1, 128, 28, 28)
        (1, 320, 14, 14)
        (1, 512, 7, 7)
    """  # noqa: E501

    # --layers: [x,x,x,x], numbers of layers for the four stages
    # --embed_dims: [x,x,x,x], embedding dims for the four stages
    # --downsamples: [x,x,x,x], has downsample or not in the four stages
    # --vit_num: (int), the num of vit blocks in the last stage
    arch_settings = {
        'l1': {
            'layers': [3, 2, 6, 4],
            'embed_dims': [48, 96, 224, 448],
            'downsamples': [False, True, True, True],
            'vit_num': 1,
        },
        'l3': {
            'layers': [4, 4, 12, 6],
            'embed_dims': [64, 128, 320, 512],
            'downsamples': [False, True, True, True],
            'vit_num': 4,
        },
        'l7': {
            'layers': [6, 6, 18, 8],
            'embed_dims': [96, 192, 384, 768],
            'downsamples': [False, True, True, True],
            'vit_num': 8,
        },
    }

    def __init__(self,
                 arch='l1',
                 in_channels=3,
                 pool_size=3,
                 mlp_ratios=4,
                 reshape_last_feat=False,
                 out_indices=-1,
                 frozen_stages=-1,
                 act_cfg=dict(type='GELU'),
                 drop_rate=0.,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.num_extra_tokens = 0  # no cls_token, no dist_token

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            default_keys = set(self.arch_settings['l1'].keys())
            assert set(arch.keys()) == default_keys, \
                f'The arch dict must have {default_keys}, ' \
                f'but got {list(arch.keys())}.'

        self.layers = arch['layers']
        self.embed_dims = arch['embed_dims']
        self.downsamples = arch['downsamples']
        assert isinstance(self.layers, list) and isinstance(
            self.embed_dims, list) and isinstance(self.downsamples, list)
        assert len(self.layers) == len(self.embed_dims) == len(
            self.downsamples)
        self.vit_num = arch['vit_num']
        self.reshape_last_feat = reshape_last_feat

        assert self.vit_num >= 0, "'vit_num' must be an integer " \
            'greater than or equal to 0.'
        assert self.vit_num <= self.layers[-1], (
            "'vit_num' must be an integer smaller than layer number")

        self._make_stem(in_channels, self.embed_dims[0])

        # set the main block in network
        network = []
        for i in range(len(self.layers)):
            if i != 0:
                in_channels = self.embed_dims[i - 1]
            else:
                in_channels = self.embed_dims[i]
            out_channels = self.embed_dims[i]
            stage = basic_blocks(
                in_channels,
                out_channels,
                i,
                self.layers,
                pool_size=pool_size,
                mlp_ratio=mlp_ratios,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                vit_num=self.vit_num,
                use_layer_scale=use_layer_scale,
                has_downsamper=self.downsamples[i])
            network.append(stage)

        self.network = ModuleList(network)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'

        self.out_indices = out_indices
        for i_layer in self.out_indices:
            if not self.reshape_last_feat and \
                    i_layer == 3 and self.vit_num > 0:
                layer = build_norm_layer(
                    dict(type='LN'), self.embed_dims[i_layer])[1]
            else:
                # use GN with 1 group as channel-first LN2D
                layer = build_norm_layer(
                    dict(type='GN', num_groups=1),
                    self.embed_dims[i_layer])[1]

            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _make_stem(self, in_channels: int, stem_channels: int):
        """make 2-ConvBNReLu stem layer."""
        self.patch_embed = Sequential(
            ConvModule(
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True),
            ConvModule(
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True))

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            if idx == len(self.network) - 1:
                N, _, H, W = x.shape
                if self.downsamples[idx]:
                    H, W = H // 2, W // 2
            x = block(x)
            if idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')

                if idx == len(self.network) - 1 and x.dim() == 3:
                    # when ``vit_num`` > 0 and in the last stage:
                    # if ``self.reshape_last_feat`` is True, reshape the
                    # features to `BCHW` format before the final
                    # normalization; if ``self.reshape_last_feat`` is
                    # False, do the normalization directly and permute
                    # the features to `BCN`.
                    if self.reshape_last_feat:
                        x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
                        x_out = norm_layer(x)
                    else:
                        x_out = norm_layer(x).permute((0, 2, 1))
                else:
                    x_out = norm_layer(x)

                outs.append(x_out.contiguous())
        return tuple(outs)

    def forward(self, x):
        # input embedding
        x = self.patch_embed(x)
        # through stages
        x = self.forward_tokens(x)
        return x

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(self.frozen_stages):
                # Include both block and downsample layer.
                module = self.network[i]
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
                if i in self.out_indices:
                    norm_layer = getattr(self, f'norm{i}')
                    norm_layer.eval()
                    for param in norm_layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(EfficientFormer, self).train(mode)
        self._freeze_stages()
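
# Illustrative usage (not part of the original source): besides the preset
# 'l1'/'l3'/'l7' arches, a custom architecture can be passed as a dict with
# exactly the four keys of ``arch_settings``. With the default
# ``out_indices=-1`` and ``reshape_last_feat=False``, the single output of
# the last stage is permuted to (B, C, N):
#
#     >>> arch = dict(
#     ...     layers=[2, 2, 4, 2],
#     ...     embed_dims=[32, 64, 128, 256],
#     ...     downsamples=[False, True, True, True],
#     ...     vit_num=1)
#     >>> model = EfficientFormer(arch=arch).eval()
#     >>> model(torch.rand(1, 3, 224, 224))[0].shape
#     torch.Size([1, 256, 49])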