Shortcuts

Source code for mmpretrain.models.selfsup.beit

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from torch import nn

from mmpretrain.models.backbones import BEiTViT
from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .base import BaseSelfSupervisor


@MODELS.register_module()
class VQKD(BaseModule):
    """Vector-Quantized Knowledge Distillation.

    The module only contains the encoder and the VectorQuantizer part.
    Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py

    Args:
        encoder_config (dict): The config of the encoder, passed as keyword
            arguments to ``BEiTViT``.
        decoder_config (dict, optional): The config of the decoder. When
            given, a decoder ``BEiTViT`` is built, but no method of this class
            uses it; currently VQKD only supports building the encoder.
            Defaults to None.
        num_embed (int): Number of embedding vectors in the codebook.
            Defaults to 8192.
        embed_dims (int): The dimension of embedding vectors in the codebook.
            Defaults to 32.
        decay (float): The decay parameter of EMA. Defaults to 0.99.
        beta (float): The multiplier for VectorQuantizer loss. Defaults to 1.
        quantize_kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Defaults to True.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 encoder_config: dict,
                 decoder_config: Optional[dict] = None,
                 num_embed: int = 8192,
                 embed_dims: int = 32,
                 decay: float = 0.99,
                 beta: float = 1.0,
                 quantize_kmeans_init: bool = True,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.encoder = BEiTViT(**encoder_config)
        if decoder_config is not None:
            self.decoder = BEiTViT(**decoder_config)

        self.quantize = NormEMAVectorQuantizer(
            num_embed=num_embed,
            embed_dims=embed_dims,
            beta=beta,
            decay=decay,
            kmeans_init=quantize_kmeans_init,
        )

        # task layer: maps encoder features (encoder embed dims) to the
        # codebook dimension ``embed_dims`` expected by the quantizer.
        encoder_dims = self.encoder.arch_settings['embed_dims']
        self.encode_task_layer = nn.Sequential(
            nn.Linear(encoder_dims, encoder_dims), nn.Tanh(),
            nn.Linear(encoder_dims, embed_dims))

    def get_tokens(self, x: torch.Tensor) -> dict:
        """Get tokens for beit pre-training.

        Args:
            x (torch.Tensor): Input images.

        Returns:
            dict: ``'token'`` holds the codebook indices reshaped to
            ``(batch, -1)``; ``'input_img'`` is the unchanged input ``x``.
        """
        _, embed_ind, _ = self.encode(x)
        output = {}
        output['token'] = embed_ind.view(x.shape[0], -1)
        output['input_img'] = x

        return output

    def encode(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the input images and get corresponding results.

        Args:
            x (torch.Tensor): Input images.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The quantized
            features, the codebook indices and the quantizer loss, in that
            order (reordered from what ``self.quantize`` returns).
        """
        # The encoder's first output is a feature map of shape (B, C, N1, N2).
        encoder_features = self.encoder(x)[0]
        B, C, N1, N2 = encoder_features.shape
        encoder_features = encoder_features.permute(0, 2, 3,
                                                    1).reshape(B, N1 * N2, C)

        # Run the task layer in full precision even under AMP, casting the
        # input to the task layer's own weight dtype.
        with torch.cuda.amp.autocast(enabled=False):
            to_quantizer_features = self.encode_task_layer(
                encoder_features.type_as(self.encode_task_layer[-1].weight))

        # Restore the spatial layout from the encoder's own patch grid
        # (N1, N2) rather than assuming a square grid of side sqrt(N); this
        # keeps non-square inputs working and is identical for square ones.
        to_quantizer_features = rearrange(
            to_quantizer_features, 'b (h w) c -> b c h w', h=N1,
            w=N2)  # reshape for quantizer
        quantize, loss, embed_ind = self.quantize(to_quantizer_features)

        return quantize, embed_ind, loss

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward function.

        Currently, only support to get tokens.
        """
        return self.get_tokens(x)['token']


@MODELS.register_module()
class BEiTPretrainViT(BEiTViT):
    """Vision Transformer for BEiT pre-training.

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"raw"``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        use_abs_pos_emb (bool): Whether or not use absolute position embedding.
            Defaults to False.
        use_rel_pos_bias (bool): Whether or not use relative position bias.
            Defaults to False.
        use_shared_rel_pos_bias (bool): Whether or not use shared relative
            position bias. Defaults to True.
        layer_scale_init_value (float): The initialization value for
            the learnable scaling of attention and FFN. Defaults to 0.1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding.
            Defaults to ``dict(padding=0)``.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Note:
        The class token is always used (``with_cls_token=True`` is passed to
        ``BEiTViT``); it is not configurable here.
    """

    def __init__(self,
                 arch: str = 'base',
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 out_indices: int = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 frozen_stages: int = -1,
                 use_abs_pos_emb: bool = False,
                 use_rel_pos_bias: bool = False,
                 use_shared_rel_pos_bias: bool = True,
                 layer_scale_init_value: float = 0.1,
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(padding=0),
                 layer_cfgs: dict = dict(),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            with_cls_token=True,
            frozen_stages=frozen_stages,
            use_abs_pos_emb=use_abs_pos_emb,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            use_rel_pos_bias=use_rel_pos_bias,
            layer_scale_init_value=layer_scale_init_value,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)

        # Learnable embedding that replaces masked patch tokens in forward().
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and cls token."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.mask_token, std=0.02)
        self.rescale_init_weight()

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights.

        The attention output projection and the second FFN linear layer of
        layer ``i`` (1-based) are divided in place by ``sqrt(2 * i)``.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.layers):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor]:
        """The BEiT style forward function.

        The function supports two kind of forward behaviors. If the ``mask``
        is not ``None``, the forward function will be executed as masked image
        modeling pre-training; if the ``mask`` is ``None``, the forward
        function will call ``super().forward()``, which extract features from
        images without mask.

        Args:
            x (torch.Tensor): Input images, which is of shape (B x C x H x W).
            mask (torch.Tensor, optional): Mask for input, which is of shape
                (B x patch_resolution[0] x patch_resolution[1]).
                Defaults to None.

        Returns:
            Tuple[torch.Tensor]: Hidden features.

        Note:
            In the masked branch, the relative position bias used for all
            layers is stored on ``self.shared_rel_pos_bias`` (``None`` when
            ``self.rel_pos_bias`` is ``None``).
        """
        if mask is None:
            return super().forward(x)

        x, patch_resolution = self.patch_embed(x)

        # replace the masked visual tokens by mask_token
        B, L, _ = x.shape
        mask_token = self.mask_token.expand(B, L, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
        x = x * (1. - w) + mask_token * w

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + resize_pos_embed(
                self.pos_embed,
                self.patch_resolution,
                patch_resolution,
                mode=self.interpolate_mode,
                num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        self.shared_rel_pos_bias = self.rel_pos_bias().to(
            mask.device) if self.rel_pos_bias is not None else None

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x, rel_pos_bias=self.shared_rel_pos_bias)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)


@MODELS.register_module()
class BEiT(BaseSelfSupervisor):
    """BEiT v1/v2.

    Implementation of `BEiT: BERT Pre-Training of Image Transformers
    <https://arxiv.org/abs/2106.08254>`_ and `BEiT v2: Masked Image Modeling
    with Vector-Quantized Visual Tokenizers
    <https://arxiv.org/abs/2208.06366>`_.
    """

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from images with the backbone, without masking.

        Args:
            inputs (torch.Tensor): The input images.

        Returns:
            Whatever ``self.backbone(inputs, mask=None)`` returns.
        """
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                fed to the masked backbone and ``inputs[1]`` to the target
                generator.
            data_samples (List[DataSample]): All elements required
                during the forward function. Each sample must carry a
                ``mask``.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.

        Raises:
            TypeError: If the head returns neither a tensor nor a
                tuple/list of losses.
        """
        mask = torch.stack([data_sample.mask for data_sample in data_samples])

        img_latent = self.backbone(inputs[0], mask)

        # inputs[1] is the target image
        with torch.no_grad():
            target = self.target_generator(inputs[1])
            target = target.detach()

        if self.with_neck:
            # BEiT v2: the neck reuses the relative position bias computed by
            # the backbone's masked forward pass above.
            feats, feats_cls_pt = self.neck(
                img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias)
            loss = self.head.loss(feats, feats_cls_pt, target, mask)
        else:
            # BEiT v1
            loss = self.head.loss(img_latent[0], target, mask)

        if isinstance(loss, torch.Tensor):
            return dict(loss=loss)

        if isinstance(loss, (tuple, list)):
            # the loss_1 and loss_2 are general reconstruction loss (patch
            # feature vectors from last layer of backbone) and early state
            # reconstruction loss (patch feature vectors from intermediate
            # layer of backbone)
            loss_1, loss_2 = loss[0], loss[1]
            losses = dict()
            # the key with prefix 'loss', like loss_1 and loss_2, will be used
            # as the final criterion
            losses['loss_1'] = loss_1
            losses['loss_2'] = loss_2
            return losses

        # Previously this fell through and returned None, which only failed
        # later and obscurely in the training loop.
        raise TypeError('The head loss must be a torch.Tensor or a tuple of '
                        f'two tensors, but got {type(loss)}.')
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.