# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor

@MODELS.register_module()
class SparK(BaseSelfSupervisor):
    """Implementation of SparK.

    Implementation of `Designing BERT for Convolutional Networks: Sparse and
    Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.

    Modified from https://github.com/keyu-tian/SparK

    Args:
        backbone (dict): Config of the backbone; forwarded to the base class.
            The built backbone must expose ``out_indices``.
        neck (dict): Config of the neck (decoder); forwarded to the base
            class. The built neck must expose ``feature_dim``.
        head (dict): Config of the head; forwarded to the base class.
        pretrained (str, optional): Forwarded to the base class.
            Defaults to None.
        data_preprocessor (dict, optional): Forwarded to the base class.
            Defaults to None.
        input_size (int): Side length of the (square) input images.
            Defaults to 224.
        downsample_raito (int): Ratio between the input size and the size of
            the smallest feature map. The (misspelled) name is kept for
            backward compatibility. Defaults to 32.
        mask_ratio (float): Fraction of feature-map positions that are
            masked out. Defaults to 0.6.
        enc_dec_norm_cfg (dict): Config passed to ``build_norm_layer`` for
            the normalization applied to every encoder feature map before it
            is handed to the decoder.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        enc_dec_norm_dim (int): Channel number of the smallest (deepest)
            feature map. It is halved for every following, larger feature
            map. Defaults to 2048.
        init_cfg (dict, optional): Forwarded to the base class.
            Defaults to None.
    """

    def __init__(
        self,
        backbone: dict,
        neck: dict,
        head: dict,
        pretrained: Optional[str] = None,
        data_preprocessor: Optional[dict] = None,
        input_size: int = 224,
        downsample_raito: int = 32,
        mask_ratio: float = 0.6,
        enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'),
        enc_dec_norm_dim: int = 2048,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.input_size = input_size
        self.downsample_raito = downsample_raito
        feature_map_size = input_size // downsample_raito
        self.feature_map_size = feature_map_size

        self.mask_ratio = mask_ratio
        # number of feature-map positions that stay visible (un-masked)
        self.len_keep = round(feature_map_size * feature_map_size *
                              (1 - mask_ratio))

        self.enc_dec_norm_cfg = enc_dec_norm_cfg
        self.enc_dec_norms = nn.ModuleList()
        self.enc_dec_projectors = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # One (norm, projector, mask token) triple per backbone output.
        # Index 0 belongs to the smallest feature map; channel numbers are
        # halved for every subsequent level.
        proj_out_dim = self.neck.feature_dim
        for i in range(len(self.backbone.out_indices)):
            enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
                                            enc_dec_norm_dim)
            self.enc_dec_norms.append(enc_dec_norm)

            # 1x1 projection for the first level, 3x3 for the others;
            # ``padding=kernel_size // 2`` keeps the spatial size unchanged.
            kernel_size = 1 if i <= 0 else 3
            proj_layer = nn.Conv2d(
                enc_dec_norm_dim,
                proj_out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=True)
            # No projection is needed on the first level when the channel
            # numbers already match.
            if i == 0 and enc_dec_norm_dim == proj_out_dim:
                proj_layer = nn.Identity()
            self.enc_dec_projectors.append(proj_layer)

            # learnable [mask] embedding, broadcast over batch and space
            mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
            trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(mask_token)

            enc_dec_norm_dim //= 2
            proj_out_dim //= 2

    def mask(self,
             shape: torch.Size,
             device: Union[torch.device, str],
             generator: Optional[torch.Generator] = None):
        """Mask generation.

        For every sample, ``self.len_keep`` positions of the
        ``f x f`` feature map (``f = self.feature_map_size``) are chosen
        uniformly at random and marked as active (visible).

        Args:
            shape (torch.Size): The shape of the input images, expected to be
                4-dimensional ``(B, C, H, W)``. Only ``B`` is used.
            device (Union[torch.device, str]): The device of the returned
                tensor.
            generator (torch.Generator, optional): Generator for random
                functions. Defaults to None

        Returns:
            torch.Tensor: Boolean mask of shape ``(B, 1, f, f)``; ``True``
            marks the positions that are kept, ``False`` the masked ones.
        """
        # Unpacking also validates that ``shape`` has exactly four entries.
        B, _, _, _ = shape
        f = self.feature_map_size
        # Sorting random noise yields a random permutation per sample; the
        # first ``len_keep`` indices of it are the visible positions.
        idx = torch.rand(B, f * f, generator=generator).argsort(dim=1)
        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)
        return torch.zeros(
            B, f * f, dtype=torch.bool, device=device).scatter_(
                dim=1, index=idx, value=True).view(B, 1, f, f)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images, ``(B, C, H, W)``.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not used by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # active mask of feature map, (B, 1, f, f)
        active_mask_feature_map = self.mask(inputs.shape, inputs.device)
        # Publish the mask on ``SparseHelper`` (class-level state);
        # presumably the sparse modules of the backbone read it from there.
        SparseHelper._cur_active = active_mask_feature_map

        # active mask of original input, (B, 1, H, W)
        active_mask_origin = active_mask_feature_map.repeat_interleave(
            self.downsample_raito,
            2).repeat_interleave(self.downsample_raito, 3)
        masked_img = inputs * active_mask_origin

        # get hierarchical encoded sparse features in a list
        # (one feature map per backbone output)
        feature_maps = self.backbone(masked_img)

        # from the smallest feature map to the largest
        feature_maps = list(feature_maps)
        feature_maps.reverse()

        cur_active = active_mask_feature_map
        feature_maps_to_dec = []
        for i, feature_map in enumerate(feature_maps):
            if feature_map is not None:
                # fill in empty positions with [mask] embeddings
                feature_map = self.enc_dec_norms[i](feature_map)
                mask_token = self.mask_tokens[i].expand_as(feature_map)
                # keep encoded values where active, [mask] token elsewhere;
                # the cast keeps both operands of ``where`` in one dtype
                feature_map = torch.where(
                    cur_active.expand_as(feature_map), feature_map,
                    mask_token.to(feature_map.dtype))
                feature_map = self.enc_dec_projectors[i](feature_map)
            feature_maps_to_dec.append(feature_map)

            # dilate the mask map to match the next (2x larger) feature map
            cur_active = cur_active.repeat_interleave(
                2, dim=2).repeat_interleave(
                    2, dim=3)

        # decode and reconstruct
        rec_img = self.neck(feature_maps_to_dec)

        # compute loss
        loss = self.head(rec_img, inputs, active_mask_feature_map)
        losses = dict(loss=loss)
        return losses
