SEResNet

class mmpretrain.models.backbones.SEResNet(depth, se_ratio=16, **kwargs)

SEResNet backbone.

Please refer to the paper "Squeeze-and-Excitation Networks" for details.
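
For orientation, the squeeze-and-excitation operation that gives the backbone its name can be sketched in plain Python. This is only an illustration of the general idea (global average pool, channel bottleneck, sigmoid gate, rescale), not mmpretrain's SELayer. The function name `se_scale`, the weight layout, and the toy values are all assumptions made for the example, and biases are omitted.

```python
import math


def se_scale(feature, w_reduce, w_expand):
    """Apply a squeeze-and-excitation style channel reweighting.

    feature:  list of C channels, each a 2-D list (H x W) of floats.
    w_reduce: (C // r) x C weights of the reduce layer (hypothetical layout).
    w_expand: C x (C // r) weights of the expand layer (hypothetical layout).
    """
    # Squeeze: global average pooling gives one scalar per channel.
    squeezed = [
        sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in feature
    ]
    # Excitation, step 1: reduce to C // r values, then apply ReLU.
    hidden = [
        max(0.0, sum(w * s for w, s in zip(row, squeezed))) for row in w_reduce
    ]
    # Excitation, step 2: expand back to C values, then apply a sigmoid gate.
    gates = [
        1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden))))
        for row in w_expand
    ]
    # Scale: multiply every value of channel c by its gate.
    scaled = [
        [[v * g for v in row] for row in ch] for ch, g in zip(feature, gates)
    ]
    return scaled, gates


# Toy input: C = 4 channels of size 2 x 2, squeeze ratio r = 2.
feature = [[[float(c + 1)] * 2 for _ in range(2)] for c in range(4)]
w_reduce = [[0.1] * 4 for _ in range(2)]  # shape (C // r) x C = 2 x 4
w_expand = [[0.5] * 2 for _ in range(4)]  # shape C x (C // r) = 4 x 2
scaled, gates = se_scale(feature, w_reduce, w_expand)
print([round(g, 4) for g in gates])
```

In this sketch the squeeze ratio r plays the role of se_ratio: a larger ratio means a narrower hidden layer between the two excitation steps.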

Parameters:
  • depth (int) – Network depth, from {50, 101, 152}.

  • se_ratio (int) – Squeeze ratio in SELayer. Default: 16.

  • in_channels (int) – Number of input image channels. Default: 3.

  • stem_channels (int) – Output channels of the stem layer. Default: 64.

  • num_stages (int) – Number of stages in the network. Default: 4.

  • strides (Sequence[int]) – Strides of the first block of each stage. Default: (1, 2, 2, 2).

  • dilations (Sequence[int]) – Dilation of each stage. Default: (1, 1, 1, 1).

  • out_indices (Sequence[int]) – Indices of the stages whose outputs are returned. If only one stage is specified, a single tensor (feature map) is returned; if multiple stages are specified, a tuple of tensors is returned. Default: (3, ).

  • style (str) – Either "pytorch" or "caffe". If set to "pytorch", the stride-two layer is the 3x3 conv layer; otherwise, the stride-two layer is the first 1x1 conv layer. Default: "pytorch".

  • deep_stem (bool) – Replace the 7x7 conv in the input stem with three 3x3 convs. Default: False.

  • avg_down (bool) – Use AvgPool instead of a strided conv when downsampling in the bottleneck. Default: False.

  • frozen_stages (int) – Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Default: -1.

  • conv_cfg (dict | None) – The config dict for conv layers. Default: None.

  • norm_cfg (dict) – The config dict for norm layers.

  • norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False.

  • with_cp (bool) – Whether to use checkpointing. Checkpointing saves some memory at the cost of slower training. Default: False.

  • zero_init_residual (bool) – Whether to zero-initialize the last norm layer in each residual block so that the block initially behaves as an identity mapping. Default: True.
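
The parameters above are usually supplied through a config dict rather than by calling the constructor directly. The fragment below is a minimal sketch of such a dict. The variable name `backbone` and the chosen values are illustrative assumptions, not a recommended configuration.

```python
# Illustrative backbone config in the dict style used by mmpretrain configs.
# Every key except `type` is a constructor parameter from the list above.
backbone = dict(
    type='SEResNet',     # class name looked up in the model registry
    depth=50,            # must be one of 50, 101, 152
    se_ratio=16,         # squeeze ratio in the SE layer
    out_indices=(3, ),   # return only the last stage
    style='pytorch',     # stride-two layer is the 3x3 conv
    frozen_stages=-1,    # do not freeze any stage
    norm_eval=False,     # keep norm layers in training mode
)

# Assumption: with mmpretrain installed, a dict like this can be turned into
# a module via the registry, e.g.
#   from mmpretrain.registry import MODELS
#   model = MODELS.build(backbone)
```

Keys that are omitted fall back to the defaults listed above.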

Example

>>> from mmpretrain.models import SEResNet
>>> import torch
>>> model = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
>>> _ = model.eval()
>>> inputs = torch.rand(1, 3, 224, 224)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
...     print(tuple(level_out.shape))
(1, 256, 56, 56)
(1, 512, 28, 28)
(1, 1024, 14, 14)
(1, 2048, 7, 7)
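
For a bottleneck-based depth such as 50, the per-stage output shapes can also be estimated by arithmetic alone, without building the model. The helper below is a rough sketch, not part of mmpretrain. Its name, `expected_stage_shapes`, is hypothetical. It assumes 64 base channels, a bottleneck expansion factor of 4, and a stem (stride-2 conv followed by stride-2 max pooling) that reduces resolution by a factor of 4. Floor division matches the real layer arithmetic only when the input size divides evenly.

```python
def expected_stage_shapes(input_size=224, strides=(1, 2, 2, 2),
                          out_indices=(0, 1, 2, 3), base_channels=64,
                          expansion=4, stem_downsample=4, batch=1):
    """Estimate the (N, C, H, W) shape of each requested stage output.

    Assumes square inputs and sizes that divide evenly at every step.
    """
    shapes = []
    # Resolution after the stem (assumed to downsample by 4 overall).
    size = input_size // stem_downsample
    for i, stride in enumerate(strides):
        # The first block of each stage applies the stage stride.
        size //= stride
        # Channel width doubles per stage; the bottleneck expands it by 4.
        channels = base_channels * 2 ** i * expansion
        if i in out_indices:
            shapes.append((batch, channels, size, size))
    return shapes


for shape in expected_stage_shapes():
    print(shape)
```

Passing out_indices=(3, ) to the helper keeps only the last entry, mirroring the backbone's default.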