# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine.dist import all_reduce
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
                         world_size: int, epsilon: float) -> torch.Tensor:
    """Apply the distributed Sinkhorn-Knopp optimization on the scores matrix
    to find the assignments.

    This function is modified from
    https://github.com/facebookresearch/swav/blob/main/main_swav.py

    Each process only holds its own slice of the samples. The global
    quantities (the total mass of ``Q`` and the per-prototype row sums) are
    therefore summed over the process group with ``all_reduce`` before they
    are used for normalization. The whole function runs under
    ``torch.no_grad()`` because it updates ``Q`` in place.

    Args:
        out (torch.Tensor): The scores matrix, indexed as
            ``(local_batch, num_prototypes)``. It is transposed to K-by-B
            internally and transposed back on return.
        sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
            algorithm.
        world_size (int): The world size of the process group.
        epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.

    Returns:
        torch.Tensor: Output of sinkhorn algorithm, same shape as ``out``.
    """
    eps_num_stab = 1e-12
    # Q is K-by-B for consistency with notations from the SwAV paper
    Q = torch.exp(out / epsilon).t()
    B = Q.shape[1] * world_size  # number of samples to assign (all ranks)
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1 over the samples of *every* process
    sum_Q = torch.sum(Q)
    all_reduce(sum_Q)
    Q /= sum_Q

    for _ in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K
        u = torch.sum(Q, dim=1, keepdim=True)
        if (u == 0).any():
            # a prototype got no mass locally; shift Q by a tiny constant so
            # the row division below cannot produce NaN/inf
            Q += eps_num_stab
            u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
        # a row spans the samples of all processes, so its sum is global
        all_reduce(u)
        Q /= u
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()
class MultiPrototypes(BaseModule):
    """Multi-prototypes for SwAV head.

    Builds one bias-free ``nn.Linear(output_dim, k)`` per entry ``k`` of
    ``num_prototypes``. Each one is registered as submodule
    ``prototypes{i}``.

    Args:
        output_dim (int): The output dim from SwAV neck.
        num_prototypes (List[int]): The number of prototypes needed.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 output_dim: int,
                 num_prototypes: List[int],
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        # nn.Module state must exist before add_module() is called; this also
        # forwards init_cfg, which was previously accepted but ignored.
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_prototypes, list)
        self.num_heads = len(num_prototypes)
        for i, k in enumerate(num_prototypes):
            # submodule names determine state_dict keys, so keep this scheme
            self.add_module('prototypes' + str(i),
                            nn.Linear(output_dim, k, bias=False))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run forward for every prototype.

        Args:
            x (torch.Tensor): Input features fed to every prototype head.

        Returns:
            List[torch.Tensor]: One output per prototype head, in index order.
        """
        return [
            getattr(self, 'prototypes' + str(i))(x)
            for i in range(self.num_heads)
        ]
@MODELS.register_module()
class SwAVLoss(BaseModule):
    """The Loss for SwAV.

    This Loss contains clustering and sinkhorn algorithms to compute Q codes.
    Part of the code is borrowed from `script
    <https://github.com/facebookresearch/swav>`_.
    The queue is built outside this module (in `engine/hooks/`) and assigned
    to ``self.queue``. While it is ``None``, no queue is used.

    Args:
        feat_dim (int): feature dimension of the prototypes.
        sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
            algorithm. Defaults to 3.
        epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
            Defaults to 0.05.
        temperature (float): temperature parameter in training loss.
            Defaults to 0.1.
        crops_for_assign (List[int]): list of crops id used for computing
            assignments. Defaults to [0, 1].
        num_crops (List[int]): list of number of crops. Defaults to [2].
        num_prototypes (int or List[int]): number of prototypes. A list
            builds a :class:`MultiPrototypes` layer instead of a single
            ``nn.Linear``. Defaults to 3000.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 feat_dim: int,
                 sinkhorn_iterations: int = 3,
                 epsilon: float = 0.05,
                 temperature: float = 0.1,
                 crops_for_assign: List[int] = [0, 1],
                 num_crops: List[int] = [2],
                 num_prototypes: Union[int, List[int]] = 3000,
                 init_cfg: Optional[Union[List[dict], dict]] = None):
        # must run before any submodule (self.prototypes) is assigned
        super().__init__(init_cfg=init_cfg)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.temperature = temperature
        # copy so instances never alias the shared mutable default lists
        self.crops_for_assign = list(crops_for_assign)
        self.num_crops = list(num_crops)
        self.use_queue = False
        self.queue = None
        # dist.is_initialized may be missing on builds without distributed
        self.world_size = (
            dist.get_world_size()
            if dist.is_available() and dist.is_initialized() else 1)

        # prototype layer
        self.prototypes = None
        if isinstance(num_prototypes, list):
            self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
        elif num_prototypes > 0:
            self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        assert self.prototypes is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of SwAV loss.

        Args:
            x (torch.Tensor): NxC input features. The crops are stacked along
                dim 0 in equal chunks. N is assumed to be a multiple of
                ``sum(num_crops)``.

        Returns:
            torch.Tensor: The returned loss.
        """
        # NOTE(review): this path reads ``self.prototypes.weight`` and slices
        # ``self.prototypes(x)`` as one tensor, i.e. it assumes the single
        # ``nn.Linear`` prototype layer rather than ``MultiPrototypes``.

        # L2-normalize every prototype (row of the weight) in place
        with torch.no_grad():
            w = nn.functional.normalize(self.prototypes.weight, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        embedding, output = x, self.prototypes(x)
        embedding = embedding.detach()

        total_crops = sum(self.num_crops)
        bs = embedding.size(0) // total_crops
        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id:bs * (crop_id + 1)].detach()
                # time to use the queue
                if self.queue is not None:
                    queue = self.queue[i]  # view: writes go to self.queue
                    # start using the queue once its last slot is filled
                    if self.use_queue or not torch.all(queue[-1, :] == 0):
                        self.use_queue = True
                        # score queued embeddings against the prototypes and
                        # prepend them to the current batch scores
                        queue_scores = torch.mm(queue,
                                                self.prototypes.weight.t())
                        out = torch.cat((queue_scores, out))
                    # fill the queue: shift by one batch, newest at the front
                    queue[bs:] = queue[:-bs].clone()
                    queue[:bs] = embedding[crop_id * bs:(crop_id + 1) * bs]

                # get assignments (batch_size * num_prototypes); keep only
                # the rows belonging to the current batch
                q = distributed_sinkhorn(out, self.sinkhorn_iterations,
                                         self.world_size, self.epsilon)[-bs:]

            # cluster assignment prediction from every other crop
            subloss = 0
            for v in np.delete(np.arange(total_crops), crop_id):
                logits = output[bs * v:bs * (v + 1)] / self.temperature
                subloss -= torch.mean(
                    torch.sum(
                        q * nn.functional.log_softmax(logits, dim=1), dim=1))
            loss += subloss / (total_crops - 1)
        loss /= len(self.crops_for_assign)
        return loss