Shortcuts

Source code for mmpretrain.models.losses.cosine_similarity_loss

# Copyright (c) OpenMMLab. All rights reserved.

from typing import Optional

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
    """Loss derived from the cosine similarity of two feature tensors.

    Both inputs are L2-normalized along their last dimension and multiplied
    element-wise; summing that product over the last dimension gives one
    similarity value per feature vector. The per-vector loss is
    ``shift_factor - scale_factor * similarity``.

    Args:
        shift_factor (float): Constant from which the scaled similarity is
            subtracted. Defaults to 0.0.
        scale_factor (float): Multiplier applied to the similarity before
            it is subtracted from ``shift_factor``. Defaults to 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the cosine similarity loss between ``pred`` and ``target``.

        Args:
            pred (torch.Tensor): The predicted features. Vectors are taken
                along the last dimension.
            target (torch.Tensor): The target features. They are multiplied
                element-wise with ``pred`` after normalization.
            mask (torch.Tensor, optional): Weights multiplied onto the
                per-vector loss. When given, the result is the weighted sum
                divided by ``mask.sum()``; a mask whose sum is zero
                therefore leads to a division by zero. When ``None``, the
                plain mean over all vectors is returned. Defaults to None.

        Returns:
            torch.Tensor: The reduced (scalar) cosine similarity loss.
        """
        unit_pred = nn.functional.normalize(pred, dim=-1)
        unit_target = nn.functional.normalize(target, dim=-1)

        # Dot product of unit vectors == cosine similarity per vector.
        similarity = (unit_pred * unit_target).sum(dim=-1)
        per_vector = self.shift_factor - self.scale_factor * similarity

        if mask is None:
            return per_vector.mean()

        # Average only over the positions selected/weighted by the mask.
        return (per_vector * mask).sum() / mask.sum()
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.