
CAEHead

class mmpretrain.models.heads.CAEHead(loss, init_cfg=None)

Head for CAE (Context Autoencoder) pre-training.

Computes the align loss and the main loss. In addition, this head derives the decoder's prediction target from the output of DALL-E.
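The target-generation step can be pictured with a small sketch. It assumes that DALL-E produces one row of scores per patch over its visual vocabulary, that the highest-scoring entry is taken as the discrete token target, and that only masked patches are kept. The helper `make_decoder_targets` is hypothetical and is not part of mmpretrain. Plain Python lists stand in for tensors so the logic stays visible.

```python
from typing import List


def make_decoder_targets(dalle_logits: List[List[float]],
                         mask: List[bool]) -> List[int]:
    """Hypothetical helper: turn per-patch DALL-E scores into token targets.

    Assumption: each row of ``dalle_logits`` holds one patch's scores over
    the visual vocabulary, and only masked patches (mask[i] is True) are
    predicted by the decoder.
    """
    targets = []
    for scores, is_masked in zip(dalle_logits, mask):
        if is_masked:
            # The highest-scoring codebook index becomes the discrete target.
            targets.append(max(range(len(scores)), key=scores.__getitem__))
    return targets


# Three patches with a 4-token vocabulary; patches 0 and 2 are masked.
logits = [[0.1, 2.0, 0.3, 0.0],
          [1.5, 0.2, 0.1, 0.0],
          [0.0, 0.1, 0.2, 3.0]]
print(make_decoder_targets(logits, [True, False, True]))  # [1, 3]
```

Patch 1 is unmasked and is skipped, so the decoder is only asked to predict tokens at the masked positions.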

Parameters:
  • loss (dict) – Config dict of the loss module.

  • tokenizer_path (str) – The path of the tokenizer. Note that this argument does not appear in the class signature above.

  • init_cfg (dict or List[dict], optional) – Initialization config dict. Defaults to None.
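Because `loss` is a config dict rather than a module instance, the head is normally declared inside a model config. The fragment below is a sketch of that shape only. The loss type name `'CAELoss'` is an assumption here and should be checked against `mmpretrain.models.losses` before use.

```python
# Hypothetical config fragment: the head receives its loss as a config dict.
# 'CAELoss' is an assumed registered loss type; verify it before use.
head = dict(
    type='CAEHead',
    loss=dict(type='CAELoss'),
    init_cfg=None,  # default: no special initialization
)
```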

loss(logits, logits_target, latent_pred, latent_target, mask)

Compute the losses.

Parameters:
  • logits (torch.Tensor) – Logits generated by the decoder.

  • logits_target (torch.Tensor) – Target generated by DALL-E for the decoder's prediction.

  • latent_pred (torch.Tensor) – Latent prediction from the regressor.

  • latent_target (torch.Tensor) – Target for the latent prediction, generated by the teacher.

Returns:

A tuple of losses:
  • loss_main (torch.Tensor): Cross entropy loss.

  • loss_align (torch.Tensor): MSE loss.

Return type:

Tuple[torch.Tensor, torch.Tensor]
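The two returned values can be illustrated with a minimal sketch in plain Python: `loss_main` as mean softmax cross-entropy between decoder logits and discrete token targets, and `loss_align` as mean squared error between the regressor's latent prediction and the teacher's latent target. The function names and the `align_weight` factor are hypothetical. This is not mmpretrain's implementation, which operates on tensors and delegates to the configured loss module.

```python
import math
from typing import List, Sequence, Tuple


def cross_entropy(logits: Sequence[Sequence[float]],
                  targets: Sequence[int]) -> float:
    """Mean softmax cross-entropy, one row of logits per target token."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


def mse(pred: List[List[float]], target: List[List[float]]) -> float:
    """Mean squared error over all elements."""
    flat_p = [x for row in pred for x in row]
    flat_t = [x for row in target for x in row]
    return sum((p - t) ** 2 for p, t in zip(flat_p, flat_t)) / len(flat_p)


def cae_losses(logits, targets, latent_pred, latent_target,
               align_weight: float = 1.0) -> Tuple[float, float]:
    """Hypothetical sketch returning (loss_main, loss_align)."""
    loss_main = cross_entropy(logits, targets)  # decoder vs. token targets
    # In a tensor framework the teacher target would be detached from the graph.
    loss_align = align_weight * mse(latent_pred, latent_target)
    return loss_main, loss_align


loss_main, loss_align = cae_losses(
    logits=[[0.0, 0.0], [0.0, 0.0]], targets=[0, 1],
    latent_pred=[[1.0, 2.0]], latent_target=[[1.0, 0.0]])
print(round(loss_main, 4), loss_align)  # 0.6931 2.0
```

With uniform logits over two classes, the cross-entropy equals ln 2 regardless of the target. The align loss averages the squared differences (0 and 4) over two elements.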
