PixelReconstructionLoss
- class mmpretrain.models.losses.PixelReconstructionLoss(criterion, channel=None)
Loss for pixel reconstruction in Masked Image Modeling.
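The computation that the following description outlines can be pictured with a small sketch: an element-wise L1 or L2 penalty between prediction and target, averaged over all elements or, when a mask is given, only over the elements the mask selects. This is a hypothetical, dependency-free illustration in plain Python, not mmpretrain's implementation. The function name pixel_reconstruction_loss, the flat-list inputs, the criterion strings "L1"/"L2", and treating the mask as per-element weights (non-zero means the element contributes, as for the masked patches in MAE) are all assumptions of this sketch. The channel argument is not modeled.

```python
from typing import Optional, Sequence


def pixel_reconstruction_loss(
    pred: Sequence[float],
    target: Sequence[float],
    mask: Optional[Sequence[float]] = None,
    criterion: str = "L2",
) -> float:
    """Hypothetical masked L1/L2 loss over flattened pixel values.

    Illustrative sketch only; not mmpretrain's PixelReconstructionLoss.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")

    # Element-wise penalty between the reconstruction and the target.
    if criterion == "L1":
        errors = [abs(p - t) for p, t in zip(pred, target)]
    elif criterion == "L2":
        errors = [(p - t) ** 2 for p, t in zip(pred, target)]
    else:
        raise ValueError(f"unsupported criterion: {criterion!r}")

    if mask is None:
        # No mask: plain mean over every element.
        return sum(errors) / len(errors)

    if len(mask) != len(errors):
        raise ValueError("mask must have the same length as pred")
    total_weight = sum(mask)
    if total_weight == 0:
        raise ValueError("mask selects no elements")
    # Mask given: weighted mean, so only elements with a non-zero
    # mask value contribute, normalized by the mask sum.
    return sum(e * m for e, m in zip(errors, mask)) / total_weight


if __name__ == "__main__":
    pred = [0.0, 1.0, 2.0, 3.0]
    target = [0.0, 0.0, 0.0, 0.0]
    mask = [0, 0, 1, 1]  # 1 = this element contributes to the loss
    print(pixel_reconstruction_loss(pred, target))             # 3.5
    print(pixel_reconstruction_loss(pred, target, mask=mask))  # 6.5
```

With an all-zero target, the unmasked L2 loss averages the squared errors 0, 1, 4 and 9 to 3.5, while the mask [0, 0, 1, 1] averages only 4 and 9 to 6.5. Dividing by the mask sum rather than by the total element count keeps the loss scale independent of how many elements the mask selects.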
This module measures the distance between the target image and the reconstructed image and computes a loss to optimize the model. Currently, it only provides L1 and L2 losses to penalize the reconstruction error. In addition, a mask can be passed to the forward function so that the loss is applied only to the region selected by the mask, as in MAE, which computes the loss only on the masked patches.
- Parameters:
- forward(pred, target, mask=None)
Forward function to compute the reconstruction loss.
- Parameters:
pred (torch.Tensor) – The reconstructed image.
target (torch.Tensor) – The target image.
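mask (torch.Tensor, optional) – The mask of the target image. If given, the loss is computed only over the region selected by the mask. Defaults to None.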
- Returns:
The reconstruction loss.
- Return type: