- class mmpretrain.models.backbones.DistilledVisionTransformer(arch='deit-base', *args, **kwargs)
Distilled Vision Transformer.
A PyTorch implementation of: Training data-efficient image transformers & distillation through attention
arch (str | dict) – Vision Transformer architecture. If a string, choose from ‘small’, ‘base’, ‘large’, ‘deit-tiny’, ‘deit-small’ and ‘deit-base’. If a dict, it should have the keys below:
embed_dims (int): The dimensions of embedding.
num_layers (int): The number of transformer encoder layers.
num_heads (int): The number of heads in attention modules.
feedforward_channels (int): The hidden dimensions in feedforward modules.
Defaults to ‘deit-base’.
in_channels (int) – The number of input channels. Defaults to 3.
out_indices (Sequence | int) – Output from which stages. Defaults to -1, meaning the last stage.
drop_rate (float) – Probability of an element to be zeroed. Defaults to 0.
drop_path_rate (float) – Stochastic depth rate. Defaults to 0.
qkv_bias (bool) – Whether to add bias for qkv in attention modules. Defaults to True.
norm_cfg (dict) – Config dict for normalization layer. Defaults to
final_norm (bool) – Whether to add an additional layer to normalize the final feature map. Defaults to True.
out_type (str) – The type of output features. Choose from:
"cls_token": A tuple with the class token and the distillation token. The shapes of both tensor are (B, C).
"featmap": The feature map tensor from the patch tokens with shape (B, C, H, W).
"avg_featmap": The global averaged feature map tensor with shape (B, C).
"raw": The raw feature tensor, including both patch tokens and class tokens, with shape (B, L, C).
interpolate_mode (str) – Select the interpolation mode for resizing the position embedding vector. Defaults to “bicubic”.
patch_cfg (dict) – Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict) – Configs of each transformer layer in encoder. Defaults to an empty dict.
init_cfg (dict, optional) – Initialization config dict. Defaults to None.
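As a minimal sketch, the arch dict described above can be written out explicitly with its four required keys. The values here mirror the standard DeiT-base configuration (embedding dimension 768, 12 layers, 12 heads, MLP hidden size 3072); the instantiation at the end is a hypothetical usage, assuming mmpretrain is installed.

```python
# Custom `arch` dict with the four keys the docstring requires.
# Values mirror the standard DeiT-base configuration; adjust them
# to define your own variant.
custom_arch = dict(
    embed_dims=768,            # The dimensions of embedding.
    num_layers=12,             # The number of transformer encoder layers.
    num_heads=12,              # The number of heads in attention modules.
    feedforward_channels=3072, # The hidden dimensions in feedforward modules.
)

# Hypothetical usage (requires mmpretrain and PyTorch to be installed):
# from mmpretrain.models.backbones import DistilledVisionTransformer
# model = DistilledVisionTransformer(arch=custom_arch, out_type='cls_token')
```

Passing a preset string such as 'deit-base' is equivalent to supplying the corresponding dict yourself; the dict form is only needed when defining a non-preset variant.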