Shortcuts

Source code for mmpretrain.models.heads.beitv1_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV1Head(BaseModule):
    """Prediction head for BEiT v1 pre-training.

    A single linear layer projects each feature vector to ``num_embed``
    logits. The configured loss module then compares those logits with the
    target token indices at the masked positions.

    Args:
        embed_dims (int): Dimension of the incoming feature vectors.
        num_embed (int): Number of classes the head predicts over.
        loss (dict): Config of the loss module, built with ``MODELS.build``.
        init_cfg (dict or List[dict], optional): Initialization config
            dict(s). Defaults to a ``TruncNormal`` init (std=0.02, bias=0)
            applied to ``Linear`` layers.
    """

    def __init__(
        self,
        embed_dims: int,
        num_embed: int,
        loss: dict,
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Maps each embedding onto num_embed class logits.
        self.cls_head = nn.Linear(embed_dims, num_embed)
        self.loss_module = MODELS.build(loss)

    def loss(self, feats: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Compute the loss on the masked positions.

        Args:
            feats (torch.Tensor): Features from the backbone. The first
                token along dim 1 is the cls token and is discarded.
            target (torch.Tensor): Output of the target generator, with
                class scores along dim 1 (presumably ``(B, num_embed, H, W)``
                -- confirm against the target generator).
            mask (torch.Tensor): Pre-training mask. Non-zero entries mark
                the positions that contribute to the loss.

        Returns:
            torch.Tensor: Whatever ``self.loss_module`` returns for the
            masked logits and their target indices.
        """
        selected = mask.flatten(1).to(torch.bool)

        # Hard token ids: the highest-scoring class at every position.
        token_ids = target.argmax(dim=1).flatten(1)
        masked_targets = token_ids[selected]

        # Keep only the patch tokens; index 0 along dim 1 is the cls token.
        patch_feats = feats[:, 1:]

        masked_logits = self.cls_head(patch_feats[selected])
        return self.loss_module(masked_logits, masked_targets)
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.