mmpretrain.evaluation.RetrievalRecall
- class mmpretrain.evaluation.RetrievalRecall(topk, collect_device='cpu', prefix=None)
Recall evaluation metric for image retrieval.
- Parameters:
topk (int | Sequence[int]) – If the ground truth label matches one of the best k predictions, the sample is regarded as a positive prediction. If the parameter is a tuple, all of the top-k recalls will be calculated and output together. Defaults to 1.
collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.
prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
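To make the top-k matching rule concrete, here is a minimal sketch in plain Python (independent of mmpretrain, so the helper name `recall_at_k` is hypothetical): a sample counts as positive when any of its first k predicted indices appears in its ground-truth index list, and the metric is the percentage of positive samples.

```python
def recall_at_k(pred_indices, target_indices, k):
    """Percentage of samples whose top-k predictions hit any ground-truth index.

    pred_indices: per-sample lists of predicted indices, best first.
    target_indices: per-sample lists of ground-truth indices.
    """
    hits = sum(
        any(p in targets for p in preds[:k])
        for preds, targets in zip(pred_indices, target_indices)
    )
    return 100.0 * hits / len(pred_indices)

# Same toy data as the docstring example below: at k=1, only
# samples 0 and 3 hit their ground truth, giving 2/4 = 50%.
y_pred = [[0], [1], [2], [3]]
y_true = [[0, 1], [2], [1], [0, 3]]
print(recall_at_k(y_pred, y_true, k=1))  # → 50.0
```

This mirrors the indices-based input mode (`pred_indices=True, target_indices=True`) of `RetrievalRecall.calculate`; for score-matrix input, the class derives the top-k indices from the scores first.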
Examples
Use in the code:

```python
>>> import torch
>>> from mmpretrain.evaluation import RetrievalRecall
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [[0], [1], [2], [3]]
>>> y_true = [[0, 1], [2], [1], [0, 3]]
>>> RetrievalRecall.calculate(
>>>     y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
[tensor([50.])]
>>> # Calculate the recall@1 and recall@5 for non-indices input.
>>> y_score = torch.rand((1000, 10))
>>> import torch.nn.functional as F
>>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
>>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
[tensor(9.3000), tensor(48.4000)]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
...     DataSample().set_gt_label([0, 1]).set_pred_score(
...         torch.rand(10))
...     for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{'retrieval/Recall@1': 20.700000762939453, 'retrieval/Recall@5': 78.5999984741211}
```
Use in OpenMMLab configs:

```python
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
test_evaluator = val_evaluator
```