
CAELoss

class mmpretrain.models.losses.CAELoss(lambd)[source]

Loss function for CAE (Context Autoencoder).

Computes the main loss and the align loss.

Parameters:

lambd (float) – The weight for the align loss.
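Because lambd is a plain multiplicative weight, doubling it doubles the align term. The dependency-free sketch below illustrates this scaling. It assumes the align loss is a mean squared error between the predicted and target latents, which this page does not state. `weighted_align_loss` is a hypothetical helper, not part of mmpretrain.

```python
def weighted_align_loss(latent_pred, latent_target, lambd):
    """Hypothetical helper: lambd times the mean squared error of two flat lists.

    Assumption: the align loss is an MSE; this is an illustration only.
    """
    if len(latent_pred) != len(latent_target) or not latent_pred:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(latent_pred)
    # Mean of squared element-wise differences.
    mse = sum((p - t) ** 2 for p, t in zip(latent_pred, latent_target)) / n
    # lambd scales the align term linearly.
    return lambd * mse


print(weighted_align_loss([1.0, 2.0], [0.0, 0.0], lambd=2.0))  # → 5.0
```

Here the MSE is (1 + 4) / 2 = 2.5, so a weight of 2.0 yields 5.0.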

forward(logits, target, latent_pred, latent_target)[source]

Forward function of CAE Loss.

Parameters:
  • logits (torch.Tensor) – The outputs from the decoder.

  • target (torch.Tensor) – The targets generated by DALL-E.

  • latent_pred (torch.Tensor) – The latent prediction from the regressor.

  • latent_target (torch.Tensor) – The latent target from the teacher network.

Returns:

The main loss and the align loss.

Return type:

Tuple[torch.Tensor, torch.Tensor]
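The forward signature maps naturally onto a two-term PyTorch loss. The following is a minimal sketch, not the mmpretrain source. It assumes the main loss is a cross-entropy between the decoder logits and the DALL-E token ids, and that the align loss is lambd times the MSE between the regressor's prediction and the teacher latent, with gradients stopped on the teacher side. The class name `CAELossSketch` and all tensor shapes are hypothetical.

```python
import torch
import torch.nn as nn


class CAELossSketch(nn.Module):
    """Hypothetical re-implementation for illustration only.

    Assumptions: main loss = cross-entropy(logits, target);
    align loss = lambd * MSE(latent_pred, latent_target.detach()).
    """

    def __init__(self, lambd: float) -> None:
        super().__init__()
        self.lambd = lambd
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

    def forward(self, logits, target, latent_pred, latent_target):
        # Main loss: classify each masked patch into a DALL-E token id.
        loss_main = self.ce(logits, target)
        # Align loss: detach the teacher latent (assumed stop-gradient),
        # then scale the MSE by lambd.
        loss_align = self.mse(latent_pred, latent_target.detach()) * self.lambd
        return loss_main, loss_align


# Dummy inputs with assumed shapes: 6 masked patches, 8192-token vocabulary,
# latent dimension 16.
torch.manual_seed(0)
loss_fn = CAELossSketch(lambd=2.0)
logits = torch.randn(6, 8192, requires_grad=True)
target = torch.randint(0, 8192, (6,))
latent_pred = torch.randn(6, 16, requires_grad=True)
latent_target = torch.randn(6, 16, requires_grad=True)

loss_main, loss_align = loss_fn(logits, target, latent_pred, latent_target)
# Summing the two terms for backprop is also an assumption of this sketch.
(loss_main + loss_align).backward()
```

Because `latent_target` is detached inside the loss, no gradient reaches it after `backward()`, while `latent_pred` and `logits` do receive gradients.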
