Source code for mmpretrain.models.heads.levit_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from .cls_head import ClsHead
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class BatchNormLinear(BaseModule):
    """A BatchNorm1d followed by a Linear layer.

    The pair can be fused into a single ``nn.Linear`` for deployment via
    :meth:`fuse`.
    """

    def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')):
        super().__init__()
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        # BN computes gamma * (x - mu) / sqrt(var + eps) + beta, and the
        # Linear computes W @ y + c; folding the two affine maps yields a
        # single Linear with the rescaled weight and shifted bias below.
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        b = self.bn.bias - self.bn.running_mean * \
            self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        w = self.linear.weight * w[None, :]
        b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias
        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear

    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x)
        return x
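

# A minimal sanity check for the fusion above, added editorially (sizes are
# hypothetical, not from the original file): in eval mode, where BatchNorm
# uses its running statistics, the fused Linear reproduces BN -> Linear.
def _demo_batchnorm_linear_fusion():
    layer = BatchNormLinear(384, 1000).eval()
    x = torch.randn(2, 384)
    ref = layer(x)  # BN, then Linear
    fused = layer.fuse()  # a plain nn.Linear with folded parameters
    assert torch.allclose(ref, fused(x), atol=1e-5)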


def fuse_parameters(module):
    """Recursively replace every child that exposes ``fuse()`` with the
    fused module it returns."""
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)
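

# Hedged usage sketch for ``fuse_parameters`` (editorial addition): it
# mutates the module tree in place, so any ``BatchNormLinear`` child is
# swapped for the ``nn.Linear`` its ``fuse()`` returns.
def _demo_fuse_parameters():
    block = nn.Sequential(BatchNormLinear(8, 4), nn.ReLU())
    fuse_parameters(block)
    assert isinstance(block[0], nn.Linear)  # BatchNormLinear was replaced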


@MODELS.register_module()
class LeViTClsHead(ClsHead):
    """LeViT classification head.

    Args:
        num_classes (int): Number of categories. Defaults to 1000.
        distillation (bool): Whether to use an extra distillation head.
            Defaults to True.
        in_channels (int): Number of channels in the input feature map.
        deploy (bool): Whether to fuse the BatchNorm layers into the Linear
            layers at construction time. Defaults to False.
    """

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        self.deploy = deploy
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)

        if self.deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.pre_logits(x)
        if self.distillation:
            x = self.head(x), self.head_dist(x)  # 2 16 384 -> 2 1000
            if not self.training:
                x = (x[0] + x[1]) / 2
            else:
                raise NotImplementedError("MMPretrain doesn't support "
                                          'training in distillation mode.')
        else:
            x = self.head(x)
        return x
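

# A hedged end-to-end sketch (editorial addition; channel and class counts
# are hypothetical): eval-mode predictions are identical before and after
# ``switch_to_deploy`` folds the BatchNorm layers away.
if __name__ == '__main__':
    head = LeViTClsHead(num_classes=10, in_channels=64).eval()
    feats = (torch.randn(2, 64), )  # ClsHead.pre_logits uses the last entry
    before = head(feats)
    head.switch_to_deploy()  # head / head_dist become plain nn.Linear
    after = head(feats)
    assert torch.allclose(before, after, atol=1e-5)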