Shortcuts

# Copyright (c) OpenMMLab. All rights reserved
import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.registry import HOOKS

@HOOKS.register_module()
r"""Set adaptive-margins in ArcFaceClsHead based on the power of
category-wise count.

A PyTorch implementation of the paper `Google Landmark Recognition 2020
Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
The margins will be
:math:`\text{f}(n) = (marginMax - marginMin) \cdot norm(n^p) + marginMin`.
The ``n`` indicates the number of occurrences of a category, and ``norm``
is min-max normalization over all categories.

Args:
margin_min (float): Lower bound of margins. Defaults to 0.05.
margin_max (float): Upper bound of margins. Defaults to 0.5.
power (float): The power applied to the category frequency. Defaults to -0.25.
"""

def __init__(self,
             margin_min: float = 0.05,
             margin_max: float = 0.5,
             power: float = -0.25) -> None:
    """Store the margin bounds and the frequency exponent.

    Args:
        margin_min (float): Lower bound of margins. Defaults to 0.05.
        margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power applied to the category frequency.
            Defaults to -0.25.
    """
    self.margin_min = margin_min
    self.margin_max = margin_max
    # Width of the [margin_min, margin_max] interval; the normalized
    # frequencies are scaled by this value when margins are generated.
    self.margin_range = margin_max - margin_min
    # The exponent is stored under the short name ``p``.
    self.p = power

def before_train(self, runner):
    """Change the margins in ArcFaceClsHead.

    The margin of a category that occurs ``n`` times is
    ``margin_min + margin_range * norm(n ** p)``, where ``norm`` is
    min-max normalization over all categories. If every category yields
    the same powered frequency, all margins are set to ``margin_min``
    instead of dividing by zero.

    Args:
        runner (obj: `Runner`): Runner. Its ``model`` is unwrapped first
            when it is a model wrapper.

    Raises:
        ValueError: If the model head is not an ``ArcFaceClsHead``, or if
            the number of generated margins differs from the head's
            ``num_classes``.
    """
    # NOTE(review): imported locally because the visible import block does
    # not provide ``ArcFaceClsHead``; confirm the path against the package.
    from mmpretrain.models.heads import ArcFaceClsHead

    model = runner.model
    if is_model_wrapper(model):
        model = model.module

    # Only ArcFaceClsHead is supported; anything else (including a model
    # without a head) is rejected up front.
    head = getattr(model, 'head', None)
    if not isinstance(head, ArcFaceClsHead):
        raise ValueError(
            'Hook SetFreqPowAdvMarginsHook could only be used '
            f'for ArcFaceClsHead, but get {type(head)}')

    # Generate margins based on the dataset.
    # NOTE(review): the statement defining ``gt_labels`` was missing from
    # this view; presumably the labels come from the training dataset.
    # Confirm against the upstream file.
    gt_labels = runner.train_dataloader.dataset.get_gt_labels()
    num_classes = head.num_classes

    # ``minlength`` keeps one slot per class even when the highest class
    # indices never occur in the training set.
    label_count = np.bincount(gt_labels, minlength=num_classes)
    label_count[label_count == 0] = 1  # At least one occurrence
    # Cast to float so that an integer ``power`` (e.g. -1) does not make
    # NumPy raise on an integer array.
    pow_freq = np.power(label_count.astype(np.float64), self.p)

    min_f, max_f = pow_freq.min(), pow_freq.max()
    spread = max_f - min_f
    if spread > 0:
        normalized_pow_freq = (pow_freq - min_f) / spread
    else:
        # All categories are equally frequent: avoid 0/0 -> NaN margins.
        normalized_pow_freq = np.zeros_like(pow_freq)
    margins = normalized_pow_freq * self.margin_range + self.margin_min

    # Compare against the unwrapped model's head; a wrapper may not expose
    # ``head`` directly.
    if len(margins) != num_classes:
        raise ValueError(
            f'Generated {len(margins)} margins, but the head expects '
            f'{num_classes} classes.')

    # NOTE(review): the statement applying the margins was missing from
    # this view; ``set_margins`` is presumably the intended setter.
    # Confirm against the upstream file.
    head.set_margins(margins)



© Copyright 2020, OpenMMLab. Revision 17a886cb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev