SwAVLoss
- class mmpretrain.models.losses.SwAVLoss(feat_dim, sinkhorn_iterations=3, epsilon=0.05, temperature=0.1, crops_for_assign=[0, 1], num_crops=[2], num_prototypes=3000, init_cfg=None)
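The loss for SwAV.
This loss combines online clustering with the Sinkhorn-Knopp algorithm, which computes the Q codes (soft assignments of features to prototypes) used as training targets. Part of the code is borrowed from a reference script. The queue is built in engine/hooks/swav_hook.py.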
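The Sinkhorn-Knopp step that produces the Q codes can be sketched in plain Python. This is a minimal, non-distributed illustration, not the MMPretrain implementation: scores is a hypothetical B x K list of prototype scores for one crop, and the global-max subtraction is an added numerical-stability choice. The loop alternately rescales prototype (column) and sample (row) totals, with epsilon and the iteration count matching the defaults listed below.

```python
import math
from typing import List

Matrix = List[List[float]]


def sinkhorn(scores: Matrix, epsilon: float = 0.05, iterations: int = 3) -> Matrix:
    """Return soft codes Q (B x K) for prototype scores (B x K); rows sum to 1."""
    batch = len(scores)
    num_prototypes = len(scores[0])
    # Subtracting the global max only rescales Q by a constant, which the
    # normalizations below cancel; it keeps math.exp from overflowing.
    peak = max(max(row) for row in scores)
    q = [[math.exp((s - peak) / epsilon) for s in row] for row in scores]
    total = sum(sum(row) for row in q)
    q = [[v / total for v in row] for row in q]
    for _ in range(iterations):
        # Equipartition: every prototype (column) receives total mass 1/K.
        col_sums = [sum(q[b][k] for b in range(batch)) for k in range(num_prototypes)]
        q = [[q[b][k] / col_sums[k] / num_prototypes for k in range(num_prototypes)]
             for b in range(batch)]
        # Every sample (row) receives total mass 1/B.
        row_sums = [sum(row) for row in q]
        q = [[v / row_sums[b] / batch for v in q[b]] for b in range(batch)]
    # Rescale so each row is a probability distribution over prototypes.
    return [[v * batch for v in row] for row in q]
```

For example, sinkhorn([[1.0, 0.0], [1.0, 0.0]]) returns [[0.5, 0.5], [0.5, 0.5]]: both samples prefer prototype 0, but the equipartition constraint splits them evenly, which is what stops every sample from collapsing onto a single prototype.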
- Parameters:
feat_dim (int) – feature dimension of the prototypes.
sinkhorn_iterations (int) – number of iterations in the Sinkhorn-Knopp algorithm. Defaults to 3.
epsilon (float) – regularization parameter for the Sinkhorn-Knopp algorithm. Defaults to 0.05.
temperature (float) – temperature parameter in the training loss. Defaults to 0.1.
crops_for_assign (List[int]) – list of crop ids used for computing assignments. Defaults to [0, 1].
num_crops (List[int]) – list of the number of crops; their sum gives the total number of crops per image. Defaults to [2].
num_prototypes (int) – number of prototypes. Defaults to 3000.
init_cfg (dict or List[dict], optional) – Initialization config dict. Defaults to None.
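MMPretrain components are usually declared in config files as dicts whose type key names the registered class, so a SwAVLoss entry might look like the sketch below. The keys come from the signature above; the values feat_dim=128 and num_crops=[2, 6] (two global crops plus six smaller crops, the multi-crop setting described in the SwAV paper) are assumptions for illustration.

```python
# Hypothetical example values; only the keys come from the SwAVLoss signature.
swav_loss_cfg = dict(
    type='SwAVLoss',
    feat_dim=128,             # assumed; must match the dimension of the features given to forward()
    sinkhorn_iterations=3,
    epsilon=0.05,
    temperature=0.1,
    crops_for_assign=[0, 1],  # Q codes are computed from crops 0 and 1
    num_crops=[2, 6],         # assumed multi-crop setting: 2 global + 6 smaller crops
    num_prototypes=3000,
)

# Every crop id used for assignment must index one of the crops.
total_crops = sum(swav_loss_cfg['num_crops'])
assert all(0 <= c < total_crops for c in swav_loss_cfg['crops_for_assign'])
```

With MMPretrain installed, a dict like this is normally built through the registry (for example MODELS.build from mmpretrain.registry) rather than by calling SwAVLoss directly.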
- forward(x)
Forward function of SwAV loss.
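The swapped prediction that the forward pass computes once the codes are available can be sketched as follows. Assumptions: the sketch starts from prototype scores rather than raw features (the loss itself owns the prototypes, hence feat_dim and num_prototypes, and projects x onto them first); crops are stacked in blocks of the batch size, crop by crop; assign is a hypothetical callable standing in for the Sinkhorn step; and the queue from engine/hooks/swav_hook.py is omitted.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def log_softmax(row: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable log-softmax of row / temperature."""
    scaled = [v / temperature for v in row]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in scaled))
    return [v - log_norm for v in scaled]


def swapped_prediction_loss(
    scores: Matrix,
    num_crops: Sequence[int],
    crops_for_assign: Sequence[int],
    assign: Callable[[Matrix], Matrix],
    temperature: float = 0.1,
) -> float:
    """Sketch of SwAV's cross-crop loss on prototype scores.

    scores has bs * sum(num_crops) rows; rows [bs*i, bs*(i+1)) belong to crop i.
    """
    total_crops = sum(num_crops)
    bs = len(scores) // total_crops
    loss = 0.0
    for crop_id in crops_for_assign:
        # Codes from one crop act as fixed targets for all the other crops.
        codes = assign(scores[bs * crop_id:bs * (crop_id + 1)])
        subloss = 0.0
        for v in range(total_crops):
            if v == crop_id:
                continue  # a crop does not predict its own codes
            crop_scores = scores[bs * v:bs * (v + 1)]
            # Cross-entropy between codes and softmax(scores / T), averaged over the batch.
            ce = 0.0
            for q_row, p_row in zip(codes, crop_scores):
                log_p = log_softmax(p_row, temperature)
                ce -= sum(q * lp for q, lp in zip(q_row, log_p))
            subloss += ce / bs
        # Average over the crops that predicted this assignment.
        loss += subloss / (total_crops - 1)
    # Average over the crops used for assignment.
    return loss / len(crops_for_assign)
```

With two crops and a batch of one, swapped_prediction_loss([[1.0, 0.0], [1.0, 0.0]], [2], [0, 1], lambda rows: [[0.0, 1.0] for _ in rows]) is about 10.0, because each crop must predict a code that its scores (divided by the temperature 0.1) strongly contradict; codes that agree with the scores drive the loss toward 0.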
- Parameters:
x (torch.Tensor) – input features of shape (N, C).
- Returns:
The computed SwAV loss.
- Return type: