# Source code for mmpretrain.models.backbones.xcit
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, Sequential
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer, to_2tuple
from .base_backbone import BaseBackbone
# Pick the floor-division callable by PyTorch version: releases before 1.8.0
# use ``torch.floor_divide``, newer ones use ``torch.div`` with
# ``rounding_mode='floor'``.
floor_div = (
    torch.floor_divide
    if digit_version(torch.__version__) < digit_version('1.8.0') else
    partial(torch.div, rounding_mode='floor'))
class ClassAttntion(BaseModule):
    """Class Attention Module.

    A PyTorch implementation of the Class Attention (CA) module introduced in
    `Going deeper with Image Transformers <https://arxiv.org/abs/2103.17239>`_.
    Adapted, with slight modifications to do CA, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Only the first token of the sequence is projected to a query; keys and
    values are computed from every token, so the module returns a single
    updated token per sample.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        # Standard dot-product attention scaling: 1 / sqrt(per-head width).
        self.scale = (dim // num_heads)**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend from the first token to all tokens.

        Args:
            x (torch.Tensor): Token sequence of shape (B, N, C); the token at
                index 0 is used as the query.

        Returns:
            torch.Tensor: The updated first token, of shape (B, 1, C).
        """
        batch, num_tokens, channels = x.shape
        heads = self.num_heads
        head_dim = channels // heads

        def split_heads(tokens, length):
            # (B, length, C) -> (B, heads, length, head_dim)
            return tokens.reshape(batch, length, heads,
                                  head_dim).permute(0, 2, 1, 3)

        # We only need to calculate query of cls token.
        query = split_heads(self.q(x[:, 0]).unsqueeze(1), 1)
        key = split_heads(self.k(x), num_tokens)
        query = query * self.scale
        value = split_heads(self.v(x), num_tokens)

        weights = query @ key.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))

        cls_out = (weights @ value).transpose(1, 2).reshape(
            batch, 1, channels)
        return self.proj_drop(self.proj(cls_out))
class PositionalEncodingFourier(BaseModule):
    """Positional Encoding using a fourier kernel.

    A PyTorch implementation of the sine/cosine positional encoding from
    `Attention is all you Need <https://arxiv.org/abs/1706.03762>`_, applied
    to a 2D grid and projected with a 1x1 convolution. Based on the
    `official XCiT code
    <https://github.com/facebookresearch/xcit/blob/master/xcit.py>`_

    Args:
        hidden_dim (int): The hidden feature dimension. Defaults to 32.
        dim (int): The output feature dimension. Defaults to 768.
        temperature (int): A control variable for position encoding.
            Defaults to 10000.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 hidden_dim: int = 32,
                 dim: int = 768,
                 temperature: int = 10000,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # 1x1 conv (no norm, no activation) mapping the concatenated y/x
        # encodings (2 * hidden_dim channels) to ``dim`` channels.
        self.token_projection = ConvModule(
            in_channels=hidden_dim * 2,
            out_channels=dim,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        """Build the positional encoding for a ``H x W`` grid.

        Args:
            B (int): Batch size the encoding is repeated for.
            H (int): Grid height.
            W (int): Grid width.

        Returns:
            torch.Tensor: Encoding of shape (B, C, H, W).
        """
        device = self.token_projection.conv.weight.device

        # 1-based row / column indices broadcast to a (1, H, W) grid.
        rows = torch.arange(1, H + 1, device=device)
        cols = torch.arange(1, W + 1, device=device)
        y_embed = rows.unsqueeze(1).repeat(1, 1, W).float()
        x_embed = cols.repeat(1, H, 1).float()

        # Normalise by the last index along each axis, then scale by 2*pi.
        y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale

        # Consecutive channel pairs share one frequency (floor(i / 2)).
        freq_index = torch.arange(self.hidden_dim, device=device).float()
        freq_index = floor_div(freq_index, 2)
        dim_t = self.temperature**(2 * freq_index / self.hidden_dim)

        def sin_cos(embed):
            # sin on even channels, cos on odd channels, interleaved back
            # into a (1, H, W, hidden_dim) tensor.
            scaled = embed[:, :, :, None] / dim_t
            pair = [scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()]
            return torch.stack(pair, dim=4).flatten(3)

        pos_x = sin_cos(x_embed)
        pos_y = sin_cos(y_embed)

        # (1, H, W, 2 * hidden_dim) -> (1, 2 * hidden_dim, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
class ConvPatchEmbed(BaseModule):
    """Patch Embedding using multiple convolution layers.

    The embedding is a stack of 3x3, stride-2 convolutions: four of them for
    a patch size of 16 and three for a patch size of 8. The last convolution
    has no activation.

    Args:
        img_size (int, tuple): input image size.
            Defaults to 224, means the size is 224*224.
        patch_size (int): The patch size in conv patch embedding.
            Must be 16 or 8. Defaults to 16.
        in_channels (int): The input channels of this module.
            Defaults to 3.
        embed_dims (int): The feature dimension. Defaults to 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size) * (
            img_size[0] // patch_size)

        # Every stage is a 3x3 / stride-2 conv, halving the resolution.
        make_stage = partial(
            ConvModule,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )

        # Output widths of the activated stages that precede the final conv.
        if patch_size == 16:
            widths = [embed_dims // 8, embed_dims // 4]
        elif patch_size == 8:
            widths = [embed_dims // 4]
        else:
            raise ValueError('For patch embedding, the patch size must be 16 '
                             f'or 8, but get patch size {self.patch_size}.')
        widths.append(embed_dims // 2)

        stages = []
        prev_width = in_channels
        for width in widths:
            stages.append(
                make_stage(in_channels=prev_width, out_channels=width))
            prev_width = width
        # Final projection to ``embed_dims`` without an activation.
        stages.append(
            make_stage(
                in_channels=prev_width,
                out_channels=embed_dims,
                act_cfg=None,
            ))

        self.proj = Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Embed an image batch into a token sequence.

        Args:
            x (torch.Tensor): Input passed to the convolution stack.

        Returns:
            tuple: The tokens of shape (B, N, C) and the patch resolution
            ``(Hp, Wp)`` of the feature map they were flattened from.
        """
        feat = self.proj(x)
        patch_resolution = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)  # (B, N, C)
        return tokens, patch_resolution
class ClassAttentionBlock(BaseModule):
    """Transformer block using Class Attention.

    The attention and the feed-forward network only update the first (class)
    token; the remaining tokens are passed through.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFN.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            A non-positive value disables the learnable scales.
            Defaults to 1.
        tokens_norm (bool): Whether to normalize all tokens or just the
            cls_token in the CA. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale_init_value=1.,
                 tokens_norm=False,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, dim)
        self.attn = ClassAttntion(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, dim)
        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        # Learnable per-channel layer scale, or a plain 1.0 when disabled.
        if layer_scale_init_value > 0:
            self.gamma1 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
            self.gamma2 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0

        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
        self.tokens_norm = tokens_norm

    def forward(self, x):
        """Update the class token of ``x``.

        Args:
            x (torch.Tensor): Token sequence whose first token (index 0 on
                dim 1) is the class token.

        Returns:
            torch.Tensor: Sequence with the same shape as ``x``.
        """
        normed = self.norm1(x)
        # Class-attention output replaces the first token; the other tokens
        # are the normalised inputs.
        attn_out = torch.cat([self.attn(normed), normed[:, 1:]], dim=1)
        x = x + self.drop_path(self.gamma1 * attn_out)

        if self.tokens_norm:
            x = self.norm2(x)
        else:
            # Only normalise the class token.
            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)

        shortcut = x
        # identity=0 makes the FFN return its output without a residual.
        cls_token = self.gamma2 * self.ffn(x[:, 0:1], identity=0)
        update = torch.cat([cls_token, x[:, 1:]], dim=1)
        return shortcut + self.drop_path(update)
class LPI(BaseModule):
    """Local Patch Interaction module.

    A PyTorch implementation of the Local Patch Interaction module from
    `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    The module allows explicit communication between tokens in 3x3 windows
    to augment the implicit communication performed by the block diagonal
    scatter attention. It is implemented with 2 layers of separable 3x3
    convolutions with GeLU and BatchNorm2d.

    Args:
        in_features (int): The input channels.
        out_features (int, optional): The output channels. Falls back to
            ``in_features`` when not given. Defaults to None.
        kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 out_features: Optional[int] = None,
                 kernel_size: int = 3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if not out_features:
            out_features = in_features
        # "Same" padding for odd kernel sizes.
        same_padding = kernel_size // 2

        # Grouped conv (one group per input channel), then activation,
        # then norm.
        self.conv1 = ConvModule(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=in_features,
            bias=True,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('conv', 'act', 'norm'))

        # Plain grouped conv without norm or activation.
        self.conv2 = ConvModule(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=out_features,
            norm_cfg=None,
            act_cfg=None)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Apply the two convolutions on the token grid.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C) with ``N == H * W``.
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        batch, num_tokens, channels = x.shape
        # (B, N, C) -> (B, C, H, W)
        grid = x.transpose(1, 2).reshape(batch, channels, H, W)
        grid = self.conv2(self.conv1(grid))
        # (B, C, H, W) -> (B, N, C)
        return grid.reshape(batch, channels, num_tokens).transpose(1, 2)
class XCA(BaseModule):
    r"""Cross-Covariance Attention module.

    A PyTorch implementation of the Cross-Covariance Attention module from
    `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    In Cross-Covariance Attention (XCA), the channels are updated using a
    weighted sum. The weights are obtained from the (softmax normalized)
    Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        # One learnable temperature per head, broadcast over the attention
        # map of that head.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply cross-covariance attention.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C).

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        batch, num_tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (qkv, B, num_heads, channels per head, N)
        packed = self.qkv(x).reshape(batch, num_tokens, 3, self.num_heads,
                                     head_dim)
        packed = packed.permute(2, 0, 3, 4, 1)
        query, key, value = packed[0], packed[1], packed[2]

        # Paper section 3.2 l2-Normalization and temperature scaling
        query = F.normalize(query, dim=-1)
        key = F.normalize(key, dim=-1)
        weights = (query @ key.transpose(-2, -1)) * self.temperature
        weights = self.attn_drop(weights.softmax(dim=-1))

        # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N, C)
        out = (weights @ value).permute(0, 3, 1, 2)
        out = out.reshape(batch, num_tokens, channels)
        return self.proj_drop(self.proj(out))
class XCABlock(BaseModule):
    """Transformer block using XCA.

    The block applies three residual branches in sequence: cross-covariance
    attention, local patch interaction and a feed-forward network, each with
    its own pre-normalization and learnable layer scale.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFNs.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        bn_norm_cfg (dict): Config dict for batchnorm in LPI.
            Defaults to ``dict(type='BN')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 layer_scale_init_value: float = 1.,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Attention branch.
        self.norm1 = build_norm_layer(norm_cfg, dim)
        self.attn = XCA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Local patch interaction branch (numbered 3, see forward()).
        self.norm3 = build_norm_layer(norm_cfg, dim)
        self.local_mp = LPI(
            in_features=dim,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        # Feed-forward branch.
        self.norm2 = build_norm_layer(norm_cfg, dim)
        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        # Per-channel layer scales, one per branch.
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))

    def forward(self, x, H: int, W: int):
        """Run the three residual branches.

        Args:
            x (torch.Tensor): Tokens; the LPI branch reshapes them to a
                ``H x W`` grid.
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            torch.Tensor: Tokens with the same shape as ``x``.
        """
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(self.gamma1 * attn_out)

        # NOTE official code has 3 then 2, so keeping it the same to be
        # consistent with loaded weights See
        # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721  # noqa: E501
        local_out = self.local_mp(self.norm3(x), H, W)
        x = x + self.drop_path(self.gamma3 * local_out)

        # identity=0 makes the FFN return its output without a residual.
        ffn_out = self.ffn(self.norm2(x), identity=0)
        x = x + self.drop_path(self.gamma2 * ffn_out)
        return x
@MODELS.register_module()
class XCiT(BaseBackbone):
    """XCiT backbone.

    A PyTorch implementation of XCiT backbone introduced by:
    `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    The backbone is a convolutional patch embedding followed by ``depth``
    XCA blocks and ``cls_attn_layers`` class-attention blocks. The class
    token is only concatenated to the sequence before the class-attention
    blocks.

    Args:
        img_size (int, tuple): Input image size. Defaults to 224.
        patch_size (int): Patch size. Defaults to 16.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 768.
        depth (int): depth of vision transformer. Defaults to 12.
        cls_attn_layers (int): Depth of Class attention layers.
            Defaults to 2.
        num_heads (int): Number of attention heads. Defaults to 12.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_pos_embed (bool): Whether to use positional encoding.
            Defaults to True.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        tokens_norm (bool): Whether to normalize all tokens or just the
            cls_token in the CA. Defaults to False.
        out_type (str): The type of output features. One of

            - ``"cls_token"``: The class token, of shape (B, C). Only
              available for outputs of the class-attention blocks.
            - ``"featmap"``: The patch tokens reshaped to (B, C, H, W).
            - ``"avg_featmap"``: The mean of the patch tokens, of shape
              (B, C).
            - ``"raw"``: The token sequence as it is, of shape (B, L, C).

            Defaults to ``"cls_token"``.
        out_indices (Sequence[int]): Output from which layers. XCA blocks
            come first, followed by the class-attention blocks; negative
            values count from the end. Defaults to (-1, ).
        final_norm (bool): Whether to apply a normalization layer after the
            last class-attention block. Defaults to True.
        frozen_stages (int): Layers to be frozen (all param fixed), and 0
            means to freeze the stem stage. ``k > 0`` additionally freezes
            the first ``k`` blocks. Defaults to -1, which means
            not freeze any parameters.
        bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and
            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='TruncNormal', layer='Linear')``.
    """

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 depth: int = 12,
                 cls_attn_layers: int = 2,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 use_pos_embed: bool = True,
                 layer_scale_init_value: float = 1.,
                 tokens_norm: bool = False,
                 out_type: str = 'cls_token',
                 out_indices: Sequence[int] = (-1, ),
                 final_norm: bool = True,
                 frozen_stages: int = -1,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=dict(type='TruncNormal', layer='Linear')):
        super(XCiT, self).__init__(init_cfg=init_cfg)

        img_size = to_2tuple(img_size)
        if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0):
            raise ValueError(f'`patch_size` ({patch_size}) should divide '
                             f'the image shape ({img_size}) evenly.')

        self.embed_dims = embed_dims

        assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token')
        self.out_type = out_type

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dims)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.xca_layers = nn.ModuleList()
        self.ca_layers = nn.ModuleList()
        self.num_layers = depth + cls_attn_layers

        for _ in range(depth):
            self.xca_layers.append(
                XCABlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path_rate,
                    bn_norm_cfg=bn_norm_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                ))

        for _ in range(cls_attn_layers):
            self.ca_layers.append(
                ClassAttentionBlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                    tokens_norm=tokens_norm,
                ))

        # Remember the flag: ``self.norm`` only exists when it is set, so
        # forward() and _freeze_stages() must check it before using the norm.
        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)

        # Transform out_indices
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            # Layers are numbered 0 .. num_layers - 1; an index equal to
            # num_layers would match no layer and silently yield no output.
            assert 0 <= out_indices[i] < self.num_layers, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        # ``num_layers`` already freezes every block; ``num_layers + 1`` is
        # still accepted so that previously valid configs keep working.
        if frozen_stages > self.num_layers + 1:
            raise ValueError('frozen_stages must not be greater than '
                             f'{self.num_layers + 1} but get {frozen_stages}')
        self.frozen_stages = frozen_stages

    def init_weights(self):
        """Initialize the weights, then the class token.

        The class token is left untouched when the weights come from a
        ``Pretrained`` init config.
        """
        super().init_weights()

        # ``init_cfg`` may also be a list of dicts, which cannot be indexed
        # with a string key.
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            return

        trunc_normal_(self.cls_token, std=.02)

    @staticmethod
    def _freeze_module(module):
        """Put ``module`` in eval mode and stop gradients for its params."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` blocks.

        Does nothing for a negative ``frozen_stages``. Otherwise the
        positional encoding (if used), the patch embedding and the position
        dropout are always frozen; the final norm is frozen as well once
        every block is frozen.
        """
        if self.frozen_stages < 0:
            return

        # freeze position embedding
        if self.use_pos_embed:
            self._freeze_module(self.pos_embed)
        # freeze patch embedding
        self._freeze_module(self.patch_embed)
        # set dropout to eval mode
        self.pos_drop.eval()

        # freeze cls_token, it is only used by the class-attention blocks
        if self.frozen_stages > len(self.xca_layers):
            self.cls_token.requires_grad = False

        # freeze the first ``frozen_stages`` blocks (XCA blocks first, then
        # class-attention blocks), clamped to the number of blocks
        num_frozen = min(self.frozen_stages, self.num_layers)
        blocks = list(self.xca_layers) + list(self.ca_layers)
        for block in blocks[:num_frozen]:
            self._freeze_module(block)

        # freeze the last layer norm if all stages are frozen
        if self.final_norm and self.frozen_stages >= self.num_layers:
            self._freeze_module(self.norm)

    def forward(self, x):
        """Extract features from an image batch.

        Args:
            x (torch.Tensor): Input images; the first dimension is the batch.

        Returns:
            tuple: One output per entry of ``out_indices`` (in layer order),
            formatted according to ``out_type``.
        """
        outs = []
        B = x.shape[0]
        # x is (B, N, C). (Hp, Wp) is the patch resolution
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp)
            x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
        x = self.pos_drop(x)

        for i, layer in enumerate(self.xca_layers):
            x = layer(x, Hp, Wp)
            if i in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), False))

        # The class token is only introduced for the class-attention blocks.
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        num_xca = len(self.xca_layers)
        last_ca = len(self.ca_layers) - 1
        for i, layer in enumerate(self.ca_layers):
            x = layer(x)
            # ``self.norm`` only exists when ``final_norm`` is enabled.
            if self.final_norm and i == last_ca:
                x = self.norm(x)
            if i + num_xca in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), True))

        return tuple(outs)

    def _format_output(self, x, hw, with_cls_token: bool):
        """Convert a token sequence to the configured ``out_type``.

        Args:
            x (torch.Tensor): Token sequence of shape (B, L, C).
            hw (tuple): Patch resolution ``(H, W)``.
            with_cls_token (bool): Whether ``x`` starts with a class token.

        Returns:
            torch.Tensor: The formatted output.

        Raises:
            ValueError: If ``out_type`` is ``'cls_token'`` but ``x`` has no
                class token.
        """
        if self.out_type == 'raw':
            return x

        if self.out_type == 'cls_token':
            if not with_cls_token:
                raise ValueError(
                    'Cannot output cls_token since there is no cls_token.')
            return x[:, 0]

        patch_token = x[:, 1:] if with_cls_token else x
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)

    def train(self, mode=True):
        """Switch the training mode and re-apply the frozen stages."""
        super().train(mode)
        self._freeze_stages()