Shortcuts

Source code for mmpretrain.models.heads.mim_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class MIMHead(BaseModule):
    """Pre-training head for Masked Image Modeling (MIM).

    The head builds a loss module from a config dict through the
    ``MODELS`` registry. Its ``loss`` method hands predictions, targets
    and an optional mask straight to that module.

    Args:
        loss (dict): Config dict describing the loss module; it is passed
            unchanged to ``MODELS.build``.
    """

    def __init__(self, loss: dict) -> None:
        super().__init__()
        # Resolve the loss config into a concrete module via the registry.
        built_loss = MODELS.build(loss)
        self.loss_module = built_loss

    def loss(self,
             pred: torch.Tensor,
             target: torch.Tensor,
             mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the head's loss by delegating to ``self.loss_module``.

        Args:
            pred (torch.Tensor): Predictions with shape B x L x C.
            target (torch.Tensor): Targets with shape B x L x C.
            mask (torch.Tensor, optional): Mask with shape B x L.
                Defaults to None; it is forwarded to the loss module
                as-is.

        Returns:
            torch.Tensor: Whatever the loss module returns for
            ``(pred, target, mask)``.
        """
        # Arguments are passed positionally, matching the loss module's
        # expected (pred, target, mask) order.
        return self.loss_module(pred, target, mask)
Read the Docs v: dev
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.