
SwAVLoss

class mmpretrain.models.losses.SwAVLoss(feat_dim, sinkhorn_iterations=3, epsilon=0.05, temperature=0.1, crops_for_assign=[0, 1], num_crops=[2], num_prototypes=3000, init_cfg=None)

The loss for SwAV.

This loss uses clustering and the Sinkhorn-Knopp algorithm to compute the Q codes (the soft assignments of features to prototypes). Part of the code is borrowed from the official SwAV script. The feature queue is built in engine/hooks/swav_hook.py.
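To illustrate how Q codes can be obtained from prototype scores, here is a minimal single-process Sinkhorn-Knopp sketch in NumPy. It assumes scores of shape (batch, num_prototypes) and omits the cross-GPU reductions a distributed implementation would need; the function name `sinkhorn` is hypothetical, and its `n_iters` and `epsilon` arguments play the roles of `sinkhorn_iterations` and `epsilon` documented below.

```python
import numpy as np


def sinkhorn(scores: np.ndarray, n_iters: int = 3, epsilon: float = 0.05) -> np.ndarray:
    """Turn prototype scores of shape (B, K) into soft codes Q of shape (B, K)."""
    # Subtracting the global max only rescales Q by a constant, which the
    # first normalization removes; it keeps exp() from overflowing.
    q = np.exp((scores - scores.max()) / epsilon).T  # (K, B)
    num_prototypes, batch_size = q.shape
    q /= q.sum()
    for _ in range(n_iters):
        # Rows: give every prototype the same total mass, 1 / K.
        q /= q.sum(axis=1, keepdims=True)
        q /= num_prototypes
        # Columns: give every sample the same total mass, 1 / B.
        q /= q.sum(axis=0, keepdims=True)
        q /= batch_size
    # Rescale so each sample's code (a row of the output) sums to 1.
    return (q * batch_size).T


rng = np.random.default_rng(0)
codes = sinkhorn(rng.uniform(-1.0, 1.0, size=(8, 4)))
print(codes.sum(axis=1))  # every entry is 1 up to rounding
```

The alternating row and column normalizations push the batch toward using all prototypes equally, which prevents every sample from collapsing onto one cluster; a smaller `epsilon` yields sharper assignments.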

Parameters:
  • feat_dim (int) – feature dimension of the prototypes.

  • sinkhorn_iterations (int) – number of iterations in Sinkhorn-Knopp algorithm. Defaults to 3.

  • epsilon (float) – regularization parameter for Sinkhorn-Knopp algorithm. Defaults to 0.05.

  • temperature (float) – temperature parameter in training loss. Defaults to 0.1.

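  • crops_for_assign (List[int]) – list of crop IDs whose codes are computed and used as assignment targets. Defaults to [0, 1].

  • num_crops (List[int]) – number of crops at each resolution (e.g. [2, 6] for two high-resolution and six low-resolution crops). Defaults to [2].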

  • num_prototypes (int) – number of prototypes. Defaults to 3000.

  • init_cfg (dict or List[dict], optional) – Initialization config dict. Defaults to None.
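`crops_for_assign` and `num_crops` together determine how the N×C input is split. The sketch below assumes, as in the SwAV swapped-prediction design, that the rows of the input are stacked crop by crop (crop i occupies rows i * batch_size to (i + 1) * batch_size) and that each assigned crop's code is predicted by every other crop. The helper `plan_swapped_prediction` is hypothetical and only does the bookkeeping.

```python
from typing import Dict, List, Tuple


def plan_swapped_prediction(
    num_rows: int, num_crops: List[int], crops_for_assign: List[int]
) -> Tuple[int, Dict[int, List[int]]]:
    """Return the per-crop batch size and, for each assigned crop, the crops predicting its code."""
    total_crops = sum(num_crops)
    if num_rows % total_crops != 0:
        raise ValueError("num_rows must equal batch_size * sum(num_crops)")
    batch_size = num_rows // total_crops
    # Crop i is assumed to occupy rows [i * batch_size, (i + 1) * batch_size).
    plan = {
        crop_id: [v for v in range(total_crops) if v != crop_id]
        for crop_id in crops_for_assign
    }
    return batch_size, plan


# Batch of 8 images with 2 large + 6 small crops -> 64 rows in x.
batch_size, plan = plan_swapped_prediction(64, [2, 6], [0, 1])
print(batch_size)  # 8
print(plan[0])     # [1, 2, 3, 4, 5, 6, 7]
```

With the defaults (`num_crops=[2]`, `crops_for_assign=[0, 1]`), the two crops simply swap roles: each one predicts the other's code.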

forward(x)

Forward function of SwAV loss.

Parameters:
  • x (torch.Tensor) – NxC input features.

Returns:
  The scalar SwAV loss.

Return type:
  torch.Tensor
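The computation behind `forward` can be approximated with the simplified, single-process NumPy sketch below; it is not mmpretrain's implementation. It assumes the rows of `x` are L2-normalized embeddings stacked crop by crop, passes the prototypes explicitly as a (num_prototypes, feat_dim) array (the real module presumably holds its own prototype layer, which would explain why the constructor takes `feat_dim` and `num_prototypes`), and ignores the queue maintained by swav_hook.py. The names `swav_loss`, `_sinkhorn` and `_log_softmax` are hypothetical.

```python
import numpy as np


def _sinkhorn(scores, n_iters, epsilon):
    # Single-process Sinkhorn-Knopp: (B, K) scores -> (B, K) codes, rows sum to 1.
    q = np.exp((scores - scores.max()) / epsilon).T
    num_prototypes, batch_size = q.shape
    q /= q.sum()
    for _ in range(n_iters):
        q /= q.sum(axis=1, keepdims=True)
        q /= num_prototypes
        q /= q.sum(axis=0, keepdims=True)
        q /= batch_size
    return (q * batch_size).T


def _log_softmax(z):
    # Numerically stable log-softmax over the prototype axis.
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def swav_loss(x, prototypes, num_crops=(2,), crops_for_assign=(0, 1),
              temperature=0.1, epsilon=0.05, sinkhorn_iterations=3):
    """x: (N, C) embeddings stacked crop by crop; prototypes: (K, C)."""
    # L2-normalize each prototype vector before scoring.
    w = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    scores = x @ w.T  # (N, K)
    total_crops = sum(num_crops)
    if scores.shape[0] % total_crops != 0:
        raise ValueError("N must equal batch_size * sum(num_crops)")
    bs = scores.shape[0] // total_crops
    loss = 0.0
    for crop_id in crops_for_assign:
        # Codes of this crop act as fixed targets (no gradient in a real loss).
        q = _sinkhorn(scores[crop_id * bs:(crop_id + 1) * bs],
                      sinkhorn_iterations, epsilon)
        subloss = 0.0
        for v in range(total_crops):
            if v == crop_id:
                continue
            # Every other crop predicts the codes via a temperature softmax.
            log_p = _log_softmax(scores[v * bs:(v + 1) * bs] / temperature)
            subloss -= np.mean(np.sum(q * log_p, axis=1))
        loss += subloss / (total_crops - 1)
    return loss / len(crops_for_assign)


rng = np.random.default_rng(0)
x = rng.normal(size=(4 * 8, 16))               # batch of 4, num_crops=[2, 6]
x /= np.linalg.norm(x, axis=1, keepdims=True)  # unit-norm embeddings
print(swav_loss(x, rng.normal(size=(10, 16)), num_crops=[2, 6]))
```

The cross-entropy is averaged first over the predicting crops and then over `crops_for_assign`, so the result is a single scalar; a lower `temperature` sharpens the predicted distributions.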
