Shortcuts

Source code for mmpretrain.models.utils.inverted_residual

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted residual block.

    The block chains an optional 1x1 expand convolution, a depthwise
    convolution, an optional SE layer and a final 1x1 linear convolution
    that carries no activation. The input is added back onto the result
    only when ``stride == 1`` and ``in_channels == out_channels``.

    Args:
        in_channels (int): Number of channels of the input.
        out_channels (int): Number of channels of the output.
        mid_channels (int): Number of channels fed to the depthwise
            convolution. When it equals ``in_channels`` the expand
            convolution is not created.
        kernel_size (int): Kernel size of the depthwise convolution.
            Defaults to 3.
        stride (int): Stride of the depthwise convolution, either 1 or 2.
            Defaults to 1.
        se_cfg (dict, optional): Keyword arguments used to build the SE
            layer. Defaults to None, meaning no SE layer is used.
        conv_cfg (dict, optional): Config dict of the convolution layers.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict of the normalization layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict of the activation layers used by the
            expand and depthwise convolutions.
            Defaults to ``dict(type='ReLU')``.
        drop_path_rate (float): Stochastic depth rate applied to the
            residual branch. Defaults to 0.
        with_cp (bool): Whether to use checkpointing. Using checkpoint
            will save some memory while slowing down the training speed.
            Defaults to False.
        init_cfg (dict | list[dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 drop_path_rate=0.,
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # The identity shortcut needs matching spatial size and channels.
        self.with_res_shortcut = (
            stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp

        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        self.with_se = se_cfg is not None
        self.with_expand_conv = (mid_channels != in_channels)

        if self.with_se:
            assert isinstance(se_cfg, dict)

        # Settings shared by the convolutions of this block.
        layer_cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        pointwise = dict(kernel_size=1, stride=1, padding=0)

        # Submodules are created in data-flow order.
        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                act_cfg=act_cfg,
                **pointwise,
                **layer_cfgs)

        # groups == channels makes this convolution depthwise.
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            act_cfg=act_cfg,
            **layer_cfgs)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        # Projection back to ``out_channels`` without any activation.
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            act_cfg=None,
            **pointwise,
            **layer_cfgs)

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """

        def _run_block(inputs):
            if self.with_expand_conv:
                feat = self.expand_conv(inputs)
            else:
                feat = inputs

            feat = self.depthwise_conv(feat)
            if self.with_se:
                feat = self.se(feat)
            feat = self.linear_conv(feat)

            if not self.with_res_shortcut:
                return feat
            return inputs + self.drop_path(feat)

        # Checkpointing is only used when the input requires gradients.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_run_block, x)
        return _run_block(x)
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.