mmpretrain.evaluation.RetrievalRecall

class mmpretrain.evaluation.RetrievalRecall(topk, collect_device='cpu', prefix=None)

Recall evaluation metric for image retrieval.

Parameters:
  • topk (int | Sequence[int]) – If the ground truth label matches one of the top-k predictions, the sample is regarded as a positive prediction. If a sequence is given, the recall is calculated for every k in it and all results are reported together. Defaults to 1.

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.

  • prefix (str, optional) – The prefix added to the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided, self.default_prefix is used instead. Defaults to None.
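
The hit rule described for topk can be illustrated without the library. The following is a minimal pure-Python sketch, not the mmpretrain implementation: the helper name recall_at_k is hypothetical, and it assumes each prediction list is already ordered from best to worst match.

```python
from typing import Sequence


def recall_at_k(pred_indices: Sequence[Sequence[int]],
                target_indices: Sequence[Sequence[int]],
                k: int) -> float:
    """Percentage of samples whose top-k predictions contain a target.

    Hypothetical helper for illustration only; it assumes every inner
    list in ``pred_indices`` is sorted from best to worst prediction.
    """
    hits = 0
    for preds, targets in zip(pred_indices, target_indices):
        relevant = set(targets)
        # A sample counts as positive if any of its k best predictions
        # is one of its ground truth indices.
        if any(p in relevant for p in preds[:k]):
            hits += 1
    return 100.0 * hits / len(pred_indices)


y_pred = [[0], [1], [2], [3]]
y_true = [[0, 1], [2], [1], [0, 3]]
print(recall_at_k(y_pred, y_true, k=1))  # 50.0
```

Samples 1 and 4 are hits and samples 2 and 3 are misses, which gives 50 percent, the same value as the basic usage example for this class.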

Examples

Use in the code:

>>> import torch
>>> from mmpretrain.evaluation import RetrievalRecall
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [[0], [1], [2], [3]]
>>> y_true = [[0, 1], [2], [1], [0, 3]]
>>> RetrievalRecall.calculate(
...     y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
[tensor([50.])]
>>> # Calculate the recall@1 and recall@5 for non-indices input.
>>> y_score = torch.rand((1000, 10))
>>> import torch.nn.functional as F
>>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
>>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
[tensor(9.3000), tensor(48.4000)]
>>>
>>> # ------------------- Use with Evaluator ------------------
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
...     DataSample().set_gt_label([0, 1]).set_pred_score(
...     torch.rand(10))
...     for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{'retrieval/Recall@1': 20.700000762939453,
 'retrieval/Recall@5': 78.5999984741211}
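
For score inputs with one-hot targets, as in the non-indices example above, the same rule applies after ranking the scores. The sketch below is a pure-Python illustration under that assumption, not the library's code; topk_indices and recall_from_scores are hypothetical names, and the tiny score table is made up for the example.

```python
from typing import List, Sequence


def topk_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest scores, best first (hypothetical helper)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def recall_from_scores(score_rows: Sequence[Sequence[float]],
                       onehot_rows: Sequence[Sequence[int]],
                       topk: Sequence[int]) -> List[float]:
    """Recall@k in percent for every k in ``topk`` (illustrative only)."""
    results = []
    for k in topk:
        hits = 0
        for scores, onehot in zip(score_rows, onehot_rows):
            # Convert the one-hot target row to a set of relevant indices.
            relevant = {i for i, flag in enumerate(onehot) if flag}
            if relevant.intersection(topk_indices(scores, k)):
                hits += 1
        results.append(100.0 * hits / len(score_rows))
    return results


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
targets = [[0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]]
print(recall_from_scores(scores, targets, topk=(1, 2)))  # [50.0, 75.0]
```

With uniformly random scores over 10 classes, this rule yields roughly 10 percent for k=1 and 50 percent for k=5, which is in line with the sample outputs shown for the random inputs above.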

Use in OpenMMLab configs:

val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
test_evaluator = val_evaluator