CAELoss
- class mmpretrain.models.losses.CAELoss(lambd)
Loss function for CAE (Context Autoencoder).
Computes the main loss and the align loss.
- Parameters:
lambd (float) – The weight for the align loss.
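The two terms can be illustrated with a small PyTorch module. This is a hedged sketch, not the mmpretrain source. The class name `CAELossSketch`, the choice of cross-entropy for the main loss and mean-squared error for the align loss, the `detach()` on the teacher latents, and all tensor sizes are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class CAELossSketch(nn.Module):
    """Hypothetical CAE-style loss (not the library implementation).

    Assumes the main loss is cross-entropy over target token ids and the
    align loss is a mean-squared error scaled by ``lambd``.
    """

    def __init__(self, lambd: float) -> None:
        super().__init__()
        self.lambd = lambd  # weight for the align loss
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

    def forward(self, logits, target, latent_pred, latent_target):
        # Main loss (assumed): classify each predicted patch into its target token id.
        loss_main = self.cross_entropy(logits, target)
        # Align loss (assumed): regress the teacher latents; detach() keeps
        # gradients from flowing into the teacher branch.
        loss_align = self.mse(latent_pred, latent_target.detach()) * self.lambd
        return loss_main, loss_align


# Random inputs: 16 patches, a vocabulary of 8192 tokens and
# 768-dimensional latents (sizes are illustrative only).
torch.manual_seed(0)
criterion = CAELossSketch(lambd=2.0)
logits = torch.randn(16, 8192)
target = torch.randint(0, 8192, (16,))
latent_pred = torch.randn(16, 768)
latent_target = torch.randn(16, 768)
loss_main, loss_align = criterion(logits, target, latent_pred, latent_target)
```

Because `lambd` multiplies only the align term, raising it shifts the balance toward matching the teacher latents without changing the main loss.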
- forward(logits, target, latent_pred, latent_target)
Forward function of the CAE loss.
- Parameters:
logits (torch.Tensor) – The outputs from the decoder.
target (torch.Tensor) – The targets generated by the DALL-E tokenizer.
latent_pred (torch.Tensor) – The latent prediction from the regressor.
latent_target (torch.Tensor) – The latent target from the teacher network.
- Returns:
The main loss and the align loss.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
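The calling contract of `forward` can be sketched with a hypothetical functional stand-in, `cae_loss`, that takes the four tensors and returns the pair in the documented order (main loss, align loss). The loss choices (cross-entropy and `lambd`-weighted MSE), the `detach()` on the teacher latents, the shapes, and summing the two terms before `backward()` are all assumptions for illustration, not confirmed library behavior.

```python
import torch
import torch.nn.functional as F


def cae_loss(logits, target, latent_pred, latent_target, lambd=2.0):
    """Hypothetical functional stand-in for CAELoss.forward.

    Assumes cross-entropy for the main loss and lambd-weighted MSE for the
    align loss; returns (main loss, align loss).
    """
    loss_main = F.cross_entropy(logits, target)
    # detach() (assumed) treats the teacher latents as a fixed target.
    loss_align = lambd * F.mse_loss(latent_pred, latent_target.detach())
    return loss_main, loss_align


# Illustrative shapes: N patches, vocabulary size V, latent dimension C.
N, V, C = 4, 10, 8
logits = torch.randn(N, V, requires_grad=True)        # decoder outputs
target = torch.randint(0, V, (N,))                    # target token ids
latent_pred = torch.randn(N, C, requires_grad=True)   # regressor predictions
latent_target = torch.randn(N, C, requires_grad=True) # teacher latents

loss_main, loss_align = cae_loss(logits, target, latent_pred, latent_target)
# Assumed usage: sum the two terms into one objective and backpropagate.
(loss_main + loss_align).backward()
```

After `backward()`, gradients reach `logits` and `latent_pred`, while `latent_target.grad` stays `None` because the teacher latents were detached in this sketch.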