CAEHead
- class mmpretrain.models.heads.CAEHead(loss, init_cfg=None)
Head for CAE Pre-training.
Computes the main loss and the alignment loss. This head also derives the prediction target from the logits produced by DALL-E.
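- Parameters:
loss (dict) – Config of the loss module.
init_cfg (dict or List[dict], optional) – Initialization config dict. Defaults to None.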
- loss(logits, logits_target, latent_pred, latent_target, mask)
Generate loss.
- Parameters:
logits (torch.Tensor) – Logits generated by decoder.
logits_target (torch.Tensor) – Target logits generated by DALL-E for the decoder prediction.
latent_pred (torch.Tensor) – Latent prediction by regressor.
latent_target (torch.Tensor) – Target for latent prediction, generated by teacher.
mask (torch.Tensor) – Mask selecting the positions whose targets are used for the main loss.
- Returns:
- The tuple of losses.
loss_main (torch.Tensor): Cross-entropy loss.
loss_align (torch.Tensor): MSE loss.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
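The two losses described above can be sketched in plain PyTorch. This is a minimal illustration, not the mmpretrain implementation: the helper names `generate_target` and `cae_losses`, the `align_weight` argument, and the tensor shapes in the comments are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def generate_target(logits_target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn tokenizer logits of assumed shape
    (B, V, H, W) into discrete token ids of shape (B, H*W)."""
    with torch.no_grad():
        return torch.argmax(logits_target, dim=1).flatten(1)


def cae_losses(logits, logits_target, latent_pred, latent_target, mask,
               align_weight: float = 1.0):
    """Hypothetical sketch of the main and alignment losses.

    Assumed shapes:
      logits        (N_masked, V)  decoder logits for the masked positions
      logits_target (B, V, H, W)   tokenizer logits
      latent_pred   (B, M, C)      regressor output
      latent_target (B, M, C)      teacher output
      mask          (B, H*W)       boolean, True at masked positions
    """
    # Keep only the token ids at masked positions: shape (N_masked,)
    target = generate_target(logits_target)[mask]
    # Main loss: cross entropy between decoder logits and token ids
    loss_main = F.cross_entropy(logits, target)
    # Alignment loss: MSE against the teacher latents, with no gradient
    # flowing into the target
    loss_align = align_weight * F.mse_loss(latent_pred, latent_target.detach())
    return loss_main, loss_align


if __name__ == "__main__":
    torch.manual_seed(0)
    B, V, H, W, M, C = 2, 16, 4, 4, 5, 8
    mask = torch.zeros(B, H * W, dtype=torch.bool)
    mask[:, :M] = True  # mask the first M patches of every image
    loss_main, loss_align = cae_losses(
        logits=torch.randn(B * M, V),
        logits_target=torch.randn(B, V, H, W),
        latent_pred=torch.randn(B, M, C),
        latent_target=torch.randn(B, M, C),
        mask=mask,
    )
    print(loss_main.shape, loss_align.shape)
```

Both returned values are scalar tensors, matching the `Tuple[torch.Tensor, torch.Tensor]` return type listed above. In the sketch the target is detached from the graph, so only the decoder and regressor branches would receive gradients.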