Shortcuts

Source code for mmpretrain.models.selfsup.cae

# Copyright (c) OpenMMLab. All rights reserved.
# Part of code is modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones import BEiTViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


class Conv2d(nn.Module):
    """Rewrite Conv2d module according to DALL-E code.

    A 2D convolution whose weight ``w`` and bias ``b`` are stored as raw
    parameters. Padding is ``(kw - 1) // 2``, which preserves the spatial size
    for odd kernel widths.

    Args:
        n_in (int): Number of input channels.
        n_out (int): Number of output channels.
        kw (int): Kernel width; the kernel is ``kw x kw``.
        use_float16 (bool): When True and the parameters live on a CUDA
            device, the input and parameters are cast to float16 for the
            convolution. Defaults to True.
        device (torch.device): Device on which the parameters are allocated.
            Defaults to ``torch.device('cpu')``.
        requires_grad (bool): Whether ``w`` and ``b`` require gradients.
            Defaults to False.
    """

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 kw: int,
                 use_float16: bool = True,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False) -> None:
        super().__init__()

        # Build and initialise plain tensors first, and only then wrap them as
        # parameters with the requested ``requires_grad``:
        #  * an in-place ``normal_`` on a leaf tensor that already requires
        #    grad raises a RuntimeError, so ``requires_grad=True`` used to
        #    crash here;
        #  * ``nn.Parameter(t)`` defaults to ``requires_grad=True``, so
        #    ``requires_grad=False`` used to be silently ignored.
        w = torch.empty((n_out, n_in, kw, kw),
                        dtype=torch.float32,
                        device=device)
        w.normal_(std=1 / math.sqrt(n_in * kw**2))

        b = torch.zeros((n_out, ), dtype=torch.float32, device=device)

        self.kw = kw
        self.w = nn.Parameter(w, requires_grad=requires_grad)
        self.b = nn.Parameter(b, requires_grad=requires_grad)
        self.use_float16 = use_float16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the stored weight and bias.

        Args:
            x (torch.Tensor): Input feature map of shape (B, n_in, H, W).

        Returns:
            torch.Tensor: Output of shape (B, n_out, H', W'), computed in
            float16 on CUDA when ``use_float16`` is set, otherwise in float32.
        """
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)


class EncoderBlock(nn.Module):
    """Rewrite EncoderBlock module according to DALL-E code.

    Computes ``id_path(x) + post_gain * res_path(x)``.

    * ``id_path`` is a 1x1 ``Conv2d`` when ``n_in != n_out`` and
      ``nn.Identity`` otherwise.
    * ``res_path`` is four ``ReLU -> Conv2d`` stages with kernel sizes 3, 3, 3
      and 1, running through a hidden width of ``n_out // 4``.

    Args:
        n_in (int): Number of input channels.
        n_out (int): Number of output channels.
        n_layers (int): Used to scale the residual branch by
            ``1 / n_layers ** 2``.
        device (torch.device, optional): Forwarded to every ``Conv2d``.
            Defaults to None.
        requires_grad (bool): Forwarded to every ``Conv2d``.
            Defaults to False.
    """

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 n_layers: int,
                 device: torch.device = None,
                 requires_grad: bool = False) -> None:
        super().__init__()
        hidden = n_out // 4
        self.n_hid = hidden
        self.post_gain = 1 / (n_layers**2)

        def conv(c_in: int, c_out: int, kernel: int) -> 'Conv2d':
            # Every convolution in the block shares the same device/grad setup.
            return Conv2d(
                c_in,
                c_out,
                kernel,
                device=device,
                requires_grad=requires_grad)

        if n_in == n_out:
            self.id_path = nn.Identity()
        else:
            self.id_path = conv(n_in, n_out, 1)

        # (in_channels, out_channels, kernel) for conv_1 .. conv_4, in order.
        stage_specs = (
            (n_in, hidden, 3),
            (hidden, hidden, 3),
            (hidden, hidden, 3),
            (hidden, n_out, 1),
        )
        named_layers = OrderedDict()
        for idx, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            named_layers[f'relu_{idx}'] = nn.ReLU()
            named_layers[f'conv_{idx}'] = conv(c_in, c_out, kernel)
        self.res_path = nn.Sequential(named_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the identity branch plus the scaled residual branch."""
        shortcut = self.id_path(x)
        return shortcut + self.post_gain * self.res_path(x)


@MODELS.register_module(name='DALL-E')
class DALLEEncoder(BaseModule):
    """DALL-E Encoder for feature extraction.

    The network is an input 7x7 conv, then four groups of ``EncoderBlock``s,
    then an output ``ReLU`` and a 1x1 conv producing ``vocab_size`` channels.
    The group widths are 1x, 2x, 4x and 8x ``n_hid``. Each of the first three
    groups ends with a 2x2 max-pool.

    Args:
        group_count (int): Number of groups in DALL-E encoder. Only used to
            compute the residual scaling of the blocks; the network itself
            always has four groups. Defaults to 4.
        n_hid (int): Dimension of hidden layers. Defaults to 256.
        n_blk_per_group (int): Number of blocks per group. Defaults to 2.
        input_channels (int): The channels of input images. Defaults to 3.
        vocab_size (int): Vocabulary size, indicating the number of classes.
            Defaults to 8192.
        device (torch.device): Device of parameters.
            Defaults to ``torch.device('cpu')``.
        requires_grad (bool): Require gradient or not. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 group_count: int = 4,
                 n_hid: int = 256,
                 n_blk_per_group: int = 2,
                 input_channels: int = 3,
                 vocab_size: int = 8192,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False,
                 init_cfg: Union[dict, List[dict], None] = None):
        super().__init__(init_cfg=init_cfg)
        self.input_channels = input_channels

        blk_range = range(n_blk_per_group)
        n_layers = group_count * n_blk_per_group
        make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
        make_blk = partial(
            EncoderBlock,
            n_layers=n_layers,
            device=device,
            requires_grad=requires_grad)

        def make_group(in_dims: int, out_dims: int,
                       with_pool: bool) -> nn.Sequential:
            # The first block maps in_dims -> out_dims; the others keep
            # out_dims. Names (block_1, ..., pool) define the state-dict keys.
            layers = [(f'block_{i + 1}',
                       make_blk(in_dims if i == 0 else out_dims, out_dims))
                      for i in blk_range]
            if with_pool:
                layers.append(('pool', nn.MaxPool2d(kernel_size=2)))
            return nn.Sequential(OrderedDict(layers))

        # Sub-modules are created in the same order as the reference DALL-E
        # implementation, so random initialisation draws are unchanged.
        self.blocks = nn.Sequential(
            OrderedDict([
                ('input', make_conv(input_channels, 1 * n_hid, 7)),
                ('group_1', make_group(1 * n_hid, 1 * n_hid, True)),
                ('group_2', make_group(1 * n_hid, 2 * n_hid, True)),
                ('group_3', make_group(2 * n_hid, 4 * n_hid, True)),
                ('group_4', make_group(4 * n_hid, 8 * n_hid, False)),
                ('output',
                 nn.Sequential(
                     OrderedDict([
                         ('relu', nn.ReLU()),
                         ('conv',
                          make_conv(
                              8 * n_hid, vocab_size, 1, use_float16=False)),
                     ]))),
            ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of DALL-E encoder.

        Args:
            x (torch.Tensor): The input images with shape (B, C, H, W). The
                tensor is cast to float32 before being processed.

        Returns:
            torch.Tensor: The output with shape (B, vocab_size, h, w), where
            ``h = H // 8`` and ``w = W // 8`` because of the three 2x2 max-pools.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel count differs from
                ``input_channels``.
        """
        if x.dim() != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model '
                             f'built for {self.input_channels}')
        # The cast guarantees float32, so no separate dtype check is needed.
        # The old check ran after the cast and could never fire.
        return self.blocks(x.float())
@MODELS.register_module()
class CAEPretrainViT(BEiTViT):
    """Vision Transformer for CAE pre-training and the implementation is based
    on BEiTViT.

    Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias', it
            will only add learnable bias for q, v. If bias is False, it will
            not add bias for q, k, v. Defaults to 'qv_bias'.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C). It only works without input
              mask.

            Defaults to ``"raw"``.
        frozen_stages (int): Forwarded to ``BEiTViT``. Defaults to -1.
        use_abs_pos_emb (bool): Forwarded to ``BEiTViT``. Defaults to True.
        use_rel_pos_bias (bool): Forwarded to ``BEiTViT``. Note that the
            masked forward path of this class calls every layer with
            ``rel_pos_bias=None``. Defaults to False.
        use_shared_rel_pos_bias (bool): Forwarded to ``BEiTViT``.
            Defaults to False.
        layer_scale_init_value (float, optional): The init value of gamma in
            BEiTTransformerEncoderLayer. Defaults to None.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict | list[dict]): Initialization config. Defaults to
            Constant(val=1) for ``LayerNorm``, TruncNormal(std=0.02) for
            ``Conv2d`` and uniform Xavier for ``Linear``.
    """

    def __init__(
        self,
        arch: str = 'b',
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        out_indices: int = -1,
        drop_rate: float = 0,
        drop_path_rate: float = 0,
        bias: Union[bool, str] = 'qv_bias',
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        final_norm: bool = True,
        out_type: str = 'raw',
        frozen_stages: int = -1,
        use_abs_pos_emb: bool = True,
        use_rel_pos_bias: bool = False,
        use_shared_rel_pos_bias: bool = False,
        layer_scale_init_value: Optional[float] = None,
        interpolate_mode: str = 'bicubic',
        patch_cfg: dict = dict(),
        layer_cfgs: dict = dict(),
        init_cfg: Union[dict, List[dict]] = [
            dict(type='Constant', val=1, layer=['LayerNorm']),
            dict(type='TruncNormal', std=0.02, layer=['Conv2d']),
            dict(type='Xavier', distribution='uniform', layer=['Linear'])
        ]
    ) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            bias=bias,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            with_cls_token=True,
            frozen_stages=frozen_stages,
            use_abs_pos_emb=use_abs_pos_emb,
            use_rel_pos_bias=use_rel_pos_bias,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            layer_scale_init_value=layer_scale_init_value,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)
        # The position embedding is filled with fixed 2D sin-cos values in
        # ``init_weights`` and is not trained.
        self.pos_embed.requires_grad = False
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and cls token.

        Unless ``init_cfg`` is a single ``Pretrained`` dict, the position
        embedding is overwritten with a 2D sin-cos embedding that includes a
        class-token slot, and ``cls_token`` is drawn from a truncated normal
        with std 0.02.
        """
        super().init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # NOTE(review): ``int(num_patches ** .5)`` assumes a square patch
            # grid.
            pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.pos_embed.shape[-1],
                cls_token=True)
            self.pos_embed.data.copy_(pos_embed.float())

            trunc_normal_(self.cls_token, std=.02)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate features for masked images.

        This function generates mask images and get the hidden features for
        visible patches.

        The function supports two kind of forward behaviors. If the ``mask``
        is not ``None``, the forward function will be executed as masked image
        modeling pre-training; if the ``mask`` is ``None``, the forward
        function will call ``super().forward()``, which extract features from
        images without mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (torch.Tensor, optional): Boolean mask for input, which is of
                shape B x L. ``True`` marks patches that are dropped. The
                reshape below assumes every sample keeps the same number of
                visible patches. Defaults to None.

        Returns:
            torch.Tensor: hidden features. With a mask, this is the class
            token followed by the visible patch tokens, shape
            (B, 1 + num_visible, C).
        """
        if mask is None:
            return super().forward(x)

        x, _ = self.patch_embed(x)
        batch_size, _, dim = x.size()
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # NOTE: unmasked embeddings. Keep only the visible patch tokens and
        # prepend the class token.
        x_unmasked = x[~mask].reshape(batch_size, -1, dim)
        x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)

        # Select the matching position embeddings: the class-token slot plus
        # the slots of the visible patches.
        pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,
                                          dim)
        pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
            batch_size, -1, dim)
        pos_embed_unmasked = torch.cat(
            (pos_embed[:, :1], pos_embed_unmasked), dim=1)
        x_unmasked = x_unmasked + pos_embed_unmasked

        x_unmasked = self.drop_after_pos(x_unmasked)

        last_idx = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)

            if i == last_idx and self.final_norm:
                x_unmasked = self.norm1(x_unmasked)

        return x_unmasked
@MODELS.register_module()
class CAE(BaseSelfSupervisor):
    """CAE.

    Implementation of `Context Autoencoder for Self-Supervised Representation
    Learning <https://arxiv.org/abs/2202.03026>`_.

    Args:
        backbone (dict): Config dict for module of backbone. The same config
            is built a second time as the momentum ``teacher``.
        neck (dict): Config dict for module of neck.
        head (dict): Config dict for module of head functions.
        target_generator: (dict, optional): The target_generator module to
            generate targets for self-supervised learning optimization, such
            as HOG, extracted features from other modules(DALL-E, CLIP), etc.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.0.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 target_generator: Optional[dict] = None,
                 base_momentum: float = 0.0,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            target_generator=target_generator,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        self.momentum = base_momentum
        # Second instance of the backbone. It is synchronised with and frozen
        # in ``init_weights`` and then tracked with ``momentum_update``.
        self.teacher = MODELS.build(backbone)

    def init_weights(self) -> None:
        """Initialize weights.

        After the regular initialisation, every teacher parameter is copied
        from the corresponding backbone parameter and excluded from gradient
        computation.
        """
        super().init_weights()

        # init the weights of teacher with those of backbone
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data.copy_(param_backbone.data)
            param_teacher.requires_grad = False

    @torch.no_grad()
    def momentum_update(self) -> None:
        """Momentum update of the teacher network.

        For every parameter pair:
        ``teacher = momentum * teacher + (1 - momentum) * backbone``.
        With the default ``base_momentum=0.0`` this is a plain copy.
        """
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            # Update in place so the teacher keeps its existing storage
            # instead of allocating a new tensor on every step.
            param_teacher.data.mul_(self.momentum).add_(
                param_backbone.data, alpha=1. - self.momentum)

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features with the backbone, without any masking.

        Args:
            inputs (torch.Tensor): Input images.

        Returns:
            Whatever ``self.backbone(inputs, mask=None)`` returns.
        """
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                fed to the backbone/teacher and ``inputs[1]`` to the target
                generator.
            data_samples (List[DataSample]): All elements required
                during the forward function. Each must carry a ``mask``.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components, with
            keys ``loss`` (sum), ``main`` and ``align``.
        """
        mask = torch.stack([data_sample.mask for data_sample in data_samples])
        mask = mask.flatten(1).to(torch.bool)

        img = inputs[0]
        batch_size = img.shape[0]

        unmasked = self.backbone(img, mask)

        # get the latent prediction for the masked patches
        with torch.no_grad():
            # The teacher receives the inverted mask, i.e. it encodes the
            # patches the backbone did not see.
            latent_target = self.teacher(img, ~mask)
            latent_target = latent_target[:, 1:, :]  # drop the class token
            self.momentum_update()

        # Split the patch position embeddings (class-token slot excluded)
        # into masked and visible sets, in the same order as the mask.
        pos_embed = self.backbone.pos_embed.expand(batch_size, -1, -1)
        embed_dims = pos_embed.shape[-1]
        patch_pos_embed = pos_embed[:, 1:]
        pos_embed_masked = patch_pos_embed[mask].reshape(
            batch_size, -1, embed_dims)
        pos_embed_unmasked = patch_pos_embed[~mask].reshape(
            batch_size, -1, embed_dims)

        # input the unmasked tokens and masked tokens to the decoder
        logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
                                        pos_embed_unmasked)
        logits = logits.view(-1, logits.shape[-1])

        # inputs[1] is the target image
        logits_target = self.target_generator(inputs[1])
        loss_main, loss_align = self.head.loss(logits, logits_target,
                                               latent_pred, latent_target,
                                               mask)

        losses = dict()
        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.