# Source code for mmpretrain.models.necks.swav_neck

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS

@MODELS.register_module()
class SwAVNeck(BaseModule):
    """The non-linear neck of SwAV: fc-bn-relu-fc-normalization.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels. If 0, the projection
            is a single ``nn.Linear(in_channels, out_channels)``.
        out_channels (int): Number of output channels. If 0, the projection
            is ``nn.Identity()``.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. Defaults to True.
        with_l2norm (bool): Whether to L2-normalize the output (along dim 1)
            after projection. Defaults to True.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to a constant (val=1) init of ``_BatchNorm`` and
            ``GroupNorm`` layers.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        with_avg_pool: bool = True,
        with_l2norm: bool = True,
        norm_cfg: dict = dict(type='SyncBN'),
        # NOTE(review): mutable list default; presumably safe because the
        # base class copies init_cfg -- confirm against mmengine BaseModule.
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        self.with_l2norm = with_l2norm
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if out_channels == 0:
            self.projection_neck = nn.Identity()
        elif hid_channels == 0:
            self.projection_neck = nn.Linear(in_channels, out_channels)
        else:
            # build_norm_layer returns a (name, layer) pair; keep the layer.
            self.norm = build_norm_layer(norm_cfg, hid_channels)[1]
            self.projection_neck = nn.Sequential(
                nn.Linear(in_channels, hid_channels),
                self.norm,
                nn.ReLU(inplace=True),
                nn.Linear(hid_channels, out_channels),
            )

    def forward_projection(self, x: torch.Tensor) -> torch.Tensor:
        """Compute projection.

        Args:
            x (torch.Tensor): The feature vectors after pooling.

        Returns:
            torch.Tensor: The output features with projection or L2-norm.
        """
        x = self.projection_neck(x)
        if self.with_l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)
        return x

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Forward function.

        Args:
            x (List[torch.Tensor]): One entry per crop group, i.e.
                ``len(x) == len(num_crops)``. The first element (``[0]``)
                of each entry is used as the feature map; presumably each
                entry is the backbone's output tuple -- confirm with caller.

        Returns:
            torch.Tensor: The projection vectors, one row per sample across
            all crop groups (``sum(num_crops) * N`` rows).
        """
        feats = []
        for _x in x:
            _x = _x[0]
            if self.with_avg_pool:
                _x = self.avgpool(_x)
            # Always keep the feature: without pooling the raw map is
            # flattened by the view below, so torch.cat never sees an
            # empty list.
            feats.append(_x)
        feat_vec = torch.cat(feats)  # [sum(num_crops) * N, C]
        feat_vec = feat_vec.view(feat_vec.size(0), -1)
        output = self.forward_projection(feat_vec)
        return output
# Read the Docs v: latest
# On Read the Docs
# Project Home
#
# Free document hosting provided by Read the Docs.