Source code for mmpretrain.models.necks.nonlinear_neck
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class NonLinearNeck(BaseModule):
"""The non-linear neck.
Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated.
For the default setting, the repeated time is 1.
The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
num_layers (int): Number of fc layers. Defaults to 2.
with_bias (bool): Whether to use bias in fc layers (except for the
last). Defaults to False.
with_last_bn (bool): Whether to add the last BN layer.
Defaults to True.
with_last_bn_affine (bool): Whether to have learnable affine parameters
in the last BN layer (set False for SimSiam). Defaults to True.
with_last_bias (bool): Whether to use bias in the last fc layer.
Defaults to False.
with_avg_pool (bool): Whether to apply the global average pooling
after backbone. Defaults to True.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='SyncBN').
init_cfg (dict or list[dict], optional): Initialization config dict.
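
    Example:
        A minimal usage sketch. ``BN1d`` replaces the default ``SyncBN``
        here so the snippet runs without a distributed process group; all
        shapes below are illustrative, not prescribed by the class:

        >>> import torch
        >>> neck = NonLinearNeck(
        ...     in_channels=512,
        ...     hid_channels=128,
        ...     out_channels=64,
        ...     norm_cfg=dict(type='BN1d'))
        >>> inputs = (torch.rand(4, 512, 7, 7), )
        >>> outputs = neck(inputs)
        >>> outputs[0].shape
        torch.Size([4, 64])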
"""

    def __init__(
self,
in_channels: int,
hid_channels: int,
out_channels: int,
num_layers: int = 2,
with_bias: bool = False,
with_last_bn: bool = True,
with_last_bn_affine: bool = True,
with_last_bias: bool = False,
with_avg_pool: bool = True,
norm_cfg: dict = dict(type='SyncBN'),
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
        self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]

        # Record the names of the dynamically added fc/bn layers so that
        # ``forward`` can look them up with ``getattr``. A ``None`` entry in
        # ``bn_names`` marks a layer built without BN (the last layer when
        # ``with_last_bn=False``).
        self.fc_names = []
        self.bn_names = []
        for i in range(1, num_layers):
            this_channels = (
                out_channels if i == num_layers - 1 else hid_channels)
            if i != num_layers - 1:
                # Intermediate layers: fc (optionally biased) followed by BN.
                self.add_module(
                    f'fc{i}',
                    nn.Linear(hid_channels, this_channels, bias=with_bias))
                self.add_module(f'bn{i}',
                                build_norm_layer(norm_cfg, this_channels)[1])
                self.bn_names.append(f'bn{i}')
            else:
                # Last layer: it has its own bias switch, and its BN is
                # optional with a configurable affine flag (e.g. SimSiam
                # disables the affine parameters).
                self.add_module(
                    f'fc{i}',
                    nn.Linear(
                        hid_channels, this_channels, bias=with_last_bias))
                if with_last_bn:
                    self.add_module(
                        f'bn{i}',
                        build_norm_layer(
                            dict(**norm_cfg, affine=with_last_bn_affine),
                            this_channels)[1])
                    self.bn_names.append(f'bn{i}')
                else:
                    self.bn_names.append(None)
            self.fc_names.append(f'fc{i}')

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Forward function.
Args:
x (Tuple[torch.Tensor]): The feature map of backbone.
Returns:
Tuple[torch.Tensor]: The output features.
"""
        assert len(x) == 1, 'NonLinearNeck expects a single backbone output.'
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        # Flatten to (N, C) before the fc layers.
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        x = self.bn0(x)
        # Apply the repeated relu-fc(-bn) blocks; ``bn_name`` is ``None``
        # when the last layer was built without BN.
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            fc = getattr(self, fc_name)
            x = self.relu(x)
            x = fc(x)
            if bn_name is not None:
                bn = getattr(self, bn_name)
                x = bn(x)
        return (x, )
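

# A minimal smoke test (an illustrative sketch, not part of the upstream
# module): build a SimSiam-style 3-layer projector. ``BN1d`` stands in for
# the default ``SyncBN``, which needs an initialized distributed process
# group, and all sizes below are made up for the demo.
if __name__ == '__main__':
    neck = NonLinearNeck(
        in_channels=2048,
        hid_channels=4096,
        out_channels=256,
        num_layers=3,
        with_last_bn_affine=False,  # the docstring's SimSiam setting
        norm_cfg=dict(type='BN1d'))
    feats = (torch.rand(2, 2048, 7, 7), )
    out = neck(feats)
    assert out[0].shape == (2, 256)
    print(out[0].shape)  # torch.Size([2, 256])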