CAENeck

class mmpretrain.models.necks.CAENeck(num_classes=8192, embed_dims=768, regressor_depth=6, decoder_depth=8, num_heads=12, mlp_ratio=4, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_cfg={'eps': 1e-06, 'type': 'LN'}, layer_scale_init_value=None, mask_tokens_num=75, init_cfg=None)

Neck for CAE Pre-training.

This module constructs the latent prediction regressor and the decoder, which produce the latent prediction and the final prediction, respectively.

Parameters:
  • num_classes (int) – The number of classes for final prediction. Defaults to 8192.

  • embed_dims (int) – The embed dims of latent feature in regressor and decoder. Defaults to 768.

  • regressor_depth (int) – The number of regressor blocks. Defaults to 6.

  • decoder_depth (int) – The number of decoder blocks. Defaults to 8.

  • num_heads (int) – The number of heads in multi-head attention. Defaults to 12.

  • mlp_ratio (int) – The expand ratio of latent features in MLP. Defaults to 4.

  • qkv_bias (bool) – Whether or not to use qkv bias. Defaults to True.

  • qk_scale (float, optional) – The scale applied to the results of qk. Defaults to None.

  • drop_rate (float) – The dropout rate. Defaults to 0.

  • attn_drop_rate (float) – The dropout rate in attention block. Defaults to 0.

  • drop_path_rate (float) – The stochastic depth (drop path) rate. Defaults to 0.

  • norm_cfg (dict) – The config of normalization layer. Defaults to dict(type='LN', eps=1e-6).

  • layer_scale_init_value (float, optional) – The init value of gamma. Defaults to None.

  • mask_tokens_num (int) – The number of mask tokens. Defaults to 75.

  • init_cfg (dict, optional) – Initialization config dict. Defaults to None.
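
The parameters above can be collected into a config dict. The following is a minimal sketch, assuming the dict-based config style used across OpenMMLab projects; it only builds a plain Python dict with the documented defaults and does not import or instantiate mmpretrain. The helper per_head_dim is hypothetical (not part of the library) and assumes that, as in standard multi-head attention, embed_dims is split evenly across num_heads.

```python
# Hypothetical neck config written in the OpenMMLab dict style.
# The values mirror the documented defaults; nothing is instantiated here.
neck = dict(
    type='CAENeck',
    num_classes=8192,
    embed_dims=768,
    regressor_depth=6,
    decoder_depth=8,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    norm_cfg=dict(type='LN', eps=1e-6),
    layer_scale_init_value=None,
    mask_tokens_num=75,
    init_cfg=None,
)


def per_head_dim(cfg):
    """Return the channel width of one attention head.

    Illustrative helper (an assumption, not a library function): it
    assumes embed_dims is divided evenly across num_heads, as in
    standard multi-head attention.
    """
    embed_dims = cfg['embed_dims']
    num_heads = cfg['num_heads']
    if embed_dims % num_heads != 0:
        raise ValueError(
            f'embed_dims ({embed_dims}) must be divisible by '
            f'num_heads ({num_heads})')
    return embed_dims // num_heads


# With the defaults, each head works on 768 // 12 channels.
head_dim = per_head_dim(neck)
print(head_dim)
```

In an OpenMMLab-style model config, a dict like this would typically sit under the neck key of the model definition; checking the head width before training is a cheap way to catch an incompatible embed_dims and num_heads pair.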

forward(x_unmasked, pos_embed_masked, pos_embed_unmasked)

Get the latent prediction and final prediction.

Parameters:
  • x_unmasked (torch.Tensor) – Features of unmasked tokens.

  • pos_embed_masked (torch.Tensor) – Position embedding of masked tokens.

  • pos_embed_unmasked (torch.Tensor) – Position embedding of unmasked tokens.

Returns:

  • logits: Final prediction.

  • latent_pred: Latent prediction.

Return type:

Tuple[torch.Tensor, torch.Tensor]

init_weights()

Initialize the weights of the module.
