MILANViT

class mmpretrain.models.selfsup.MILANViT(arch='b', img_size=224, patch_size=16, out_indices=-1, drop_rate=0, drop_path_rate=0, norm_cfg={'eps': 1e-06, 'type': 'LN'}, final_norm=True, out_type='raw', interpolate_mode='bicubic', patch_cfg={}, layer_cfgs={}, mask_ratio=0.75, init_cfg=None)

Vision Transformer for MILAN pre-training.

Implementation of the encoder for MILAN: Masked Image Pretraining on Language Assisted Representation.

This module inherits from MAEViT and only overrides the forward function, replacing random masking with attention masking.

attention_masking(x, mask_ratio, importance)

Generate attention mask for MILAN.

This is the difference from MAEViT, which masks patches at random. Attention masking selects the patches to keep according to importance: the higher a patch's importance, the more likely it is kept.

Parameters:
  • x (torch.Tensor) – Input patch tokens, of shape B x L x C.

  • mask_ratio (float) – The ratio of patches to be masked.

  • importance (torch.Tensor) – Importance of each patch, of shape B x L.

Returns:

  • x_masked: masked image

  • ids_restore: the ids to restore original image

  • ids_keep: ids of the kept patches

  • ids_dump: ids of the removed patches

Return type:

Tuple[torch.Tensor, …]
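
The masking step described above can be sketched as a standalone function. This is a minimal illustration, not the library's code: the name attention_masking_sketch is hypothetical, and it assumes that importance holds non-negative weights and that the kept patches are drawn by weighted sampling without replacement (torch.multinomial).

```python
import torch


def attention_masking_sketch(x, mask_ratio, importance):
    """Keep a subset of patches, favouring high-importance ones.

    x: (B, L, C) patch tokens; importance: (B, L) non-negative weights.
    """
    B, L, C = x.shape
    len_keep = int(L * (1 - mask_ratio))

    # Assumption: draw a full weighted permutation per sample, so that
    # high-importance patches tend to appear early in the order.
    ids_shuffle = torch.multinomial(importance, L, replacement=False)
    # ids_restore[b, i] is the position of patch i in the shuffled order.
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    ids_dump = ids_shuffle[:, len_keep:]

    # Gather the kept tokens along the sequence dimension.
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))
    return x_masked, ids_restore, ids_keep, ids_dump


torch.manual_seed(0)
x = torch.randn(2, 16, 8)
importance = torch.rand(2, 16) + 1e-6  # strictly positive weights
x_masked, ids_restore, ids_keep, ids_dump = attention_masking_sketch(
    x, 0.75, importance)
```

With mask_ratio=0.75 and L=16, four patches are kept per sample and twelve are removed; ids_keep and ids_dump together cover every patch index exactly once.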

forward(x, importance)

Generate features for masked images.

The function supports two kinds of forward behavior. If importance is a torch.Tensor, the function masks patches according to importance and returns the hidden features of the visible patches. The higher the importance, the more likely the patch is kept. The importance is calculated by CLIP, as the cross attention between the class token and all other tokens in the last layer. If importance is None, the forward function calls super().forward(), which extracts features from images without masking.
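
As a rough illustration of how a per-patch importance could be derived from class-token attention, the sketch below computes the attention of the class-token query over the patch keys and averages it across heads. The function name, the (B, H, 1 + L, D) layout of q and k, the softmax over patches only, and the head averaging are all assumptions made for this example, not the CLIP or MILAN implementation.

```python
import torch


def cls_attention_importance(q, k):
    """q, k: (B, H, 1 + L, D) per-head queries and keys, class token first."""
    d = q.shape[-1]
    # Scaled dot product of the class-token query with every patch key.
    logits = q[:, :, :1, :] @ k[:, :, 1:, :].transpose(-2, -1) / d ** 0.5
    # Simplification: normalise over the patch tokens only.
    attn = logits.softmax(dim=-1)  # (B, H, 1, L)
    # Assumption: average over heads to get one score per patch.
    return attn.mean(dim=1).squeeze(1)  # (B, L)


torch.manual_seed(0)
q = torch.randn(2, 4, 17, 8)
k = torch.randn(2, 4, 17, 8)
importance = cls_attention_importance(q, k)
```

Each row of the result is a non-negative weight vector over the L patches that sums to one, which is the form a B x L importance argument would take.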

Parameters:
  • x (torch.Tensor) – Input images, of shape B x C x H x W.

  • importance (torch.Tensor, optional) – Importance of each patch, of shape B x L.

Returns:

  • x (torch.Tensor): hidden features of the visible patches, of shape B x (L * (1 - mask_ratio)) x C.

  • ids_restore (torch.Tensor): ids to restore original image.

  • ids_keep (torch.Tensor): ids of the kept patches.

  • ids_dump (torch.Tensor): ids of the removed patches.

Return type:

Tuple[torch.Tensor, …]
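
To show what ids_restore is for, the following sketch pads a set of visible tokens with placeholders and reorders them back to the original patch order. It is a minimal example under stated assumptions: the masking step is replaced by a random permutation, the placeholder is a zero tensor standing in for whatever a decoder would use, and no class token is involved.

```python
import torch

torch.manual_seed(0)
B, L, C = 2, 16, 8
len_keep = 4
tokens = torch.randn(B, L, C)

# Stand-in for the masking step: a random permutation per sample.
ids_shuffle = torch.argsort(torch.rand(B, L), dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

# Pad the visible tokens with placeholders for the removed patches, then
# unshuffle so every token returns to its original position.
placeholder = torch.zeros(B, L - len_keep, C)
shuffled = torch.cat([visible, placeholder], dim=1)
restored = torch.gather(
    shuffled, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

# Binary mask in the original patch order: 0 = kept, 1 = removed.
mask = torch.ones(B, L)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)
```

After the gather, kept positions of restored hold the original tokens and removed positions hold the placeholder, with mask marking which is which.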
