Shortcuts

mmpretrain.models.backbones.deit 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from .vision_transformer import VisionTransformer


[文档]@MODELS.register_module() class DistilledVisionTransformer(VisionTransformer): """Distilled Vision Transformer. A PyTorch implement of : `Training data-efficient image transformers & distillation through attention <https://arxiv.org/abs/2012.12877>`_ Args: arch (str | dict): Vision Transformer architecture. If use string, choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' and 'deit-base'. If use dict, it should have below keys: - **embed_dims** (int): The dimensions of embedding. - **num_layers** (int): The number of transformer encoder layers. - **num_heads** (int): The number of heads in attention modules. - **feedforward_channels** (int): The hidden dimensions in feedforward modules. Defaults to 'deit-base'. img_size (int | tuple): The expected input image shape. Because we support dynamic input shape, just set the argument to the most common input image shape. Defaults to 224. patch_size (int | tuple): The patch size in patch embedding. Defaults to 16. in_channels (int): The num of input channels. Defaults to 3. out_indices (Sequence | int): Output from which stages. Defaults to -1, means the last stage. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. out_type (str): The type of output features. Please choose from - ``"cls_token"``: A tuple with the class token and the distillation token. The shapes of both tensor are (B, C). - ``"featmap"``: The feature map tensor from the patch tokens with shape (B, C, H, W). - ``"avg_featmap"``: The global averaged feature map tensor with shape (B, C). - ``"raw"``: The raw feature tensor includes patch tokens and class tokens with shape (B, L, C). Defaults to ``"cls_token"``. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. layer_cfgs (Sequence | dict): Configs of each transformer layer in encoder. Defaults to an empty dict. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ num_extra_tokens = 2 # class token and distillation token def __init__(self, arch='deit-base', *args, **kwargs): super(DistilledVisionTransformer, self).__init__( arch=arch, with_cls_token=True, *args, **kwargs, ) self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) def forward(self, x): B = x.shape[0] x, patch_resolution = self.patch_embed(x) # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) x = x + self.resize_pos_embed( self.pos_embed, self.patch_resolution, patch_resolution, mode=self.interpolate_mode, num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) outs = [] for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1 and self.final_norm: x = self.ln1(x) if i in self.out_indices: outs.append(self._format_output(x, patch_resolution)) return tuple(outs) def _format_output(self, x, hw): if self.out_type == 'cls_token': return x[:, 0], x[:, 1] return super()._format_output(x, hw) def init_weights(self): super(DistilledVisionTransformer, self).init_weights() if not (isinstance(self.init_cfg, dict) and self.init_cfg['type'] == 'Pretrained'): trunc_normal_(self.dist_token, std=0.02)
Read the Docs v: dev
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.