
Source code for mmpretrain.models.heads.levit_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.models.heads import ClsHead
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class BatchNormLinear(BaseModule):
    """A ``BatchNorm1d`` followed by a ``Linear`` layer that can be fused
    into a single ``Linear`` at deploy time."""

    def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')):
        super(BatchNormLinear, self).__init__()
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        # Fold the eval-mode BatchNorm into the Linear layer. BN computes
        # y = x * w + b with w = gamma / sqrt(var + eps) and
        # b = beta - mean * gamma / sqrt(var + eps).
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        b = self.bn.bias - self.bn.running_mean * \
            self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        # Linear(y) = W (x * w + b) + c = (W * w) x + (W @ b + c), so scale
        # the columns of the Linear weight and shift its bias accordingly.
        w = self.linear.weight * w[None, :]
        b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias

        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear

    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x)
        return x
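
The fusion can be sanity-checked numerically. A minimal sketch, not part of the module itself (the 384/1000 sizes, batch size and tolerance are illustrative assumptions): in eval mode, the fused Linear must reproduce the BatchNorm-then-Linear output up to floating-point error.

layer = BatchNormLinear(384, 1000).eval()   # eval mode: BN uses running stats
x = torch.randn(2, 384)
with torch.no_grad():
    before = layer(x)      # BatchNorm1d followed by Linear, two ops
    fused = layer.fuse()   # a single nn.Linear with folded weight and bias
    after = fused(x)
assert torch.allclose(before, after, atol=1e-5)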


def fuse_parameters(module):
    """Recursively replace every child module that exposes ``fuse()`` with
    the fused module it returns."""
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)
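
As a small illustration (the container below is hypothetical, not from this file), fuse_parameters swaps each fusable child in place:

container = nn.Sequential(BatchNormLinear(16, 8))
fuse_parameters(container)
assert isinstance(container[0], nn.Linear)  # BatchNormLinear replaced by its fused Linear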


@MODELS.register_module()
class LeViTClsHead(ClsHead):
    """LeViT classification head with an optional distillation head."""

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super(LeViTClsHead, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        self.deploy = False  # set True only after fusion in switch_to_deploy()
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)

        if deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.pre_logits(x)
        if self.distillation:
            x = self.head(x), self.head_dist(x)  # 2 16 384 -> 2 1000
            if not self.training:
                x = (x[0] + x[1]) / 2
            else:
                raise NotImplementedError("MMPretrain doesn't support "
                                          'training in distillation mode.')
        else:
            x = self.head(x)
        return x
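
A rough usage sketch (the 384-dimensional features, batch size and tolerance are illustrative assumptions, not values taken from this file): the head consumes the tuple of backbone features, averages the classification and distillation logits at inference time, and can be folded into plain Linear layers for deployment.

import torch
from mmpretrain.models.heads import LeViTClsHead

head = LeViTClsHead(num_classes=1000, distillation=True, in_channels=384).eval()
feats = (torch.randn(2, 384), )   # ClsHead.pre_logits keeps the last element
logits = head(feats)              # eval mode: (head + head_dist) / 2
print(logits.shape)               # torch.Size([2, 1000])

head.switch_to_deploy()           # fold every BatchNormLinear into its Linear
assert torch.allclose(logits, head(feats), atol=1e-5)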