BEiTV2Head
- class mmpretrain.models.heads.BEiTV2Head(embed_dims, num_embed, loss, init_cfg={'bias': 0, 'layer': 'Linear', 'std': 0.02, 'type': 'TruncNormal'})
Head for BEiT v2 Pre-training.
Compute the logits and the cross-entropy loss.
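- Parameters:
embed_dims (int) – Dimension of the input token features.
num_embed (int) – Number of target classes, i.e. the size of the visual-token vocabulary predicted for each masked patch.
loss (dict) – Config of the loss module.
init_cfg (dict or list[dict], optional) – Initialization config. Defaults to truncated-normal initialization of the Linear layers (std=0.02, bias=0).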
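To make the constructor arguments concrete, here is a minimal sketch in plain PyTorch. `SimpleBEiTV2Head` is a hypothetical stand-in, not the mmpretrain class: it assumes the `loss` config resolves to ordinary cross-entropy, and it reproduces the default `init_cfg` by hand with `nn.init.trunc_normal_` (std 0.02) and a zero bias. The sizes 768 and 8192 are illustrative only.

```python
import torch
from torch import nn


class SimpleBEiTV2Head(nn.Module):
    """Hypothetical, simplified stand-in for BEiTV2Head (not the mmpretrain class)."""

    def __init__(self, embed_dims: int, num_embed: int) -> None:
        super().__init__()
        # A single linear classifier: embed_dims-dim feature -> num_embed logits.
        self.cls_head = nn.Linear(embed_dims, num_embed)
        # Assumption: the `loss` config resolves to plain cross-entropy.
        self.loss_module = nn.CrossEntropyLoss()
        # Hand-written equivalent of the default init_cfg:
        # TruncNormal on Linear layers with std=0.02 and bias=0.
        nn.init.trunc_normal_(self.cls_head.weight, std=0.02)
        nn.init.constant_(self.cls_head.bias, 0.0)


# Illustrative sizes: 768-dim token features, 8192 target classes.
head = SimpleBEiTV2Head(embed_dims=768, num_embed=8192)
logits = head.cls_head(torch.randn(5, 768))
print(logits.shape)  # torch.Size([5, 8192])
```

Each token feature is mapped to one logit per target class, so a batch of 5 features yields a (5, 8192) logit matrix.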
- loss(feats, feats_cls_pt, target, mask)
Generate loss.
- Parameters:
feats (torch.Tensor) – Features from backbone.
feats_cls_pt (torch.Tensor) – Features from the late layers used for class-token (CLS) pre-training.
target (torch.Tensor) – Target generated by target_generator.
mask (torch.Tensor) – Generated mask for pre-training.
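The loss computation can be sketched as follows, under stated assumptions: `feats` and `feats_cls_pt` have shape (B, L, C) with one row per patch, `target` holds one visual-token id per patch with shape (B, L), and `mask` can be flattened to (B, L). `beitv2_head_loss` is a hypothetical helper, not an mmpretrain API, and it uses `F.cross_entropy` in place of the configurable loss module. In this sketch only masked positions are scored, and one linear classifier is shared by both feature sets, giving two loss terms.

```python
import torch
import torch.nn.functional as F
from torch import nn


def beitv2_head_loss(cls_head: nn.Linear, feats: torch.Tensor,
                     feats_cls_pt: torch.Tensor, target: torch.Tensor,
                     mask: torch.Tensor):
    """Hypothetical sketch of the two-branch masked-token loss."""
    mask = mask.flatten(1).to(torch.bool)   # (B, L) boolean mask
    target = target[mask]                   # (N,) token ids of masked patches
    # The same classifier scores both feature sets, only at masked positions.
    logits = cls_head(feats[mask])                # (N, num_embed)
    logits_cls_pt = cls_head(feats_cls_pt[mask])  # (N, num_embed)
    loss_1 = F.cross_entropy(logits, target)
    loss_2 = F.cross_entropy(logits_cls_pt, target)
    return loss_1, loss_2


torch.manual_seed(0)
B, H, W, C, num_embed = 2, 4, 4, 32, 100
cls_head = nn.Linear(C, num_embed)
feats = torch.randn(B, H * W, C)
feats_cls_pt = torch.randn(B, H * W, C)
target = torch.randint(0, num_embed, (B, H * W))
# Deterministic mask: the top two rows of each 4x4 patch grid are masked.
mask = torch.zeros(B, H, W, dtype=torch.bool)
mask[:, :2, :] = True
loss_1, loss_2 = beitv2_head_loss(cls_head, feats, feats_cls_pt, target, mask)
print(loss_1.item(), loss_2.item())
```

Because unmasked positions are dropped before the classifier, changing the features at visible patches leaves both losses unchanged.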