SVT¶
- class mmpretrain.models.backbones.SVT(arch, in_channels=3, out_indices=(3,), qkv_bias=False, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_cfg={'type': 'LN'}, norm_after_stage=False, init_cfg=None)[source]¶
The backbone of Twins-SVT.
This backbone is the implementation of Twins: Revisiting the Design of Spatial Attention in Vision Transformers.
- Parameters:
arch (dict | str) – SVT architecture: either a str key from the arch zoo, or a detailed configuration dict with 8 keys. All value lists in the dict must have the same length (one entry per stage):
depths (List[int]): The number of encoder layers in each stage.
embed_dims (List[int]): Embedding dimension in each stage.
patch_sizes (List[int]): The patch sizes in each stage.
num_heads (List[int]): Numbers of attention head in each stage.
strides (List[int]): The strides in each stage.
mlp_ratios (List[int]): The ratios of mlp in each stage.
sr_ratios (List[int]): The spatial-reduction ratios of the GSA-encoder layers in each stage.
window_sizes (List[int]): The window sizes of the LSA-encoder layers in each stage.
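As a sketch of the constraint above, a custom arch dict carries the 8 keys with equal-length value lists. The particular numbers below are illustrative, not an official variant; only the key names and the equal-length requirement come from the docstring:

```python
# Hypothetical custom SVT arch dict; key names follow the list above,
# but these values are illustrative and do not name an official variant.
custom_arch = {
    'embed_dims':   [64, 128, 320, 512],
    'depths':       [2, 2, 10, 4],
    'num_heads':    [2, 4, 10, 16],
    'patch_sizes':  [4, 2, 2, 2],
    'strides':      [4, 2, 2, 2],
    'mlp_ratios':   [4, 4, 4, 4],
    'sr_ratios':    [8, 4, 2, 1],
    'window_sizes': [7, 7, 7, 7],
}

# All eight value lists must have the same length (one entry per stage).
num_stages = len(custom_arch['depths'])
assert all(len(v) == num_stages for v in custom_arch.values())

# Such a dict could then be passed as the `arch` argument:
# model = SVT(arch=custom_arch)
```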
in_channels (int) – Number of input channels. Defaults to 3.
out_indices (tuple[int]) – Output from which stages. Defaults to (3, ).
qkv_bias (bool) – Enable bias for qkv if True. Defaults to False.
drop_rate (float) – Dropout rate. Defaults to 0.
attn_drop_rate (float) – Dropout ratio of attention weight. Defaults to 0.0.
drop_path_rate (float) – Stochastic depth rate. Defaults to 0.0.
norm_cfg (dict) – Config dict for normalization layer. Defaults to dict(type='LN').
norm_after_stage (bool, List[bool]) – Add extra norm after each stage. Defaults to False.
init_cfg (dict, optional) – The Config for initialization. Defaults to None.
Examples
>>> from mmpretrain.models import SVT
>>> import torch
>>> svt_cfg = {'arch': "small",
...            'norm_after_stage': [False, False, False, True]}
>>> model = SVT(**svt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> svt_cfg["out_indices"] = (0, 1, 2, 3)
>>> svt_cfg["norm_after_stage"] = [True, True, True, True]
>>> model = SVT(**svt_cfg)
>>> output = model(x)
>>> for feat in output:
...     print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])