
ConfusionMatrix

class mmpretrain.evaluation.ConfusionMatrix(num_classes=None, collect_device='cpu', prefix=None)[source]

A metric to calculate confusion matrix for single-label tasks.

Parameters:
  • num_classes (int, optional) – The number of classes. Defaults to None.

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.

  • prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.

Examples

  1. The basic usage.

>>> import torch
>>> from mmpretrain.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
>>> # plot the confusion matrix
>>> import matplotlib.pyplot as plt
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.randint(10, (1000, ))
>>> matrix = ConfusionMatrix.calculate(y_score, y_true)
>>> ConfusionMatrix().plot(matrix)
>>> plt.show()

  2. In the config file

val_evaluator = dict(type='ConfusionMatrix')
test_evaluator = dict(type='ConfusionMatrix')
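The matrix returned by `calculate` stores ground-truth classes in rows and predictions in columns: in the basic-usage output, the sample with target 2 and prediction 1 lands in row 2, column 1. Under that layout, per-class recall, per-class precision and overall accuracy can be read off the diagonal, the row sums and the column sums. The short sketch below uses the matrix from the first example. These derived quantities are computed by hand here and are not outputs of `ConfusionMatrix` itself.

```python
import torch

# Matrix printed in the basic-usage example (rows: ground truth, columns: prediction).
matrix = torch.tensor([[1, 0, 0, 0],
                       [0, 1, 0, 0],
                       [0, 1, 0, 0],
                       [0, 0, 0, 1]])

# Correctly classified samples per class sit on the diagonal.
tp = matrix.diag().float()

# Recall: correct predictions / number of samples whose true class is c (row sums).
# clamp(min=1) avoids dividing by zero for classes with no samples.
recall = tp / matrix.sum(dim=1).clamp(min=1)

# Precision: correct predictions / number of times class c was predicted (column sums).
# Class 2 is never predicted, so its precision is reported as 0 instead of NaN.
precision = tp / matrix.sum(dim=0).clamp(min=1)

# Overall accuracy: trace / total number of samples.
accuracy = tp.sum() / matrix.sum()

# recall.tolist()    -> [1.0, 1.0, 0.0, 1.0]
# precision.tolist() -> [1.0, 0.5, 0.0, 1.0]
# accuracy.item()    -> 0.75
```

Reporting 0 rather than NaN for a never-predicted class is a choice made for this sketch; other tools may handle that case differently.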

static calculate(pred, target, num_classes=None)[source]

Calculate the confusion matrix for a single-label task.

Parameters:
  • pred (torch.Tensor | np.ndarray | Sequence) – The prediction results. It can be either labels with shape (N, ) or per-class scores with shape (N, C).

  • target (torch.Tensor | np.ndarray | Sequence) – The target of each prediction with shape (N, ).

  • num_classes (int, optional) – The number of classes. This argument is required if pred contains labels rather than scores. Defaults to None.
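The parameters above allow `pred` to hold either labels of shape (N, ) or scores of shape (N, C), and only the score form carries the class count. The following is a minimal sketch of how such inputs could be normalized to label tensors, assuming CPU inputs. `to_label_tensors` is a hypothetical helper for illustration, not part of mmpretrain.

```python
import torch


def to_label_tensors(pred, target, num_classes=None):
    """Hypothetical helper: turn labels or scores into 1-D label tensors.

    Assumes `pred` is either labels of shape (N,) or scores of shape (N, C),
    given as a list, NumPy array or CPU tensor.
    """
    # torch.as_tensor accepts sequences, NumPy arrays and tensors alike.
    pred = torch.as_tensor(pred)
    target = torch.as_tensor(target).long()

    if pred.ndim == 2:
        # Scores: the class count is the score width C, and the label is the argmax.
        num_classes = pred.size(1)
        pred_label = pred.argmax(dim=1)
    else:
        # Labels: the largest label seen may be smaller than the true class count,
        # so the caller must supply num_classes explicitly.
        if num_classes is None:
            raise ValueError('num_classes is required when pred contains labels')
        pred_label = pred.long()

    return pred_label, target, num_classes


# Usage: three samples scored over two classes.
scores = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels, target, n = to_label_tensors(scores, [1, 0, 0])
# labels.tolist() -> [1, 0, 1], n -> 2
```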

Returns:

The confusion matrix.

Return type:

torch.Tensor
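The following pure-PyTorch sketch makes the layout of the returned matrix concrete by counting (target, prediction) pairs with `torch.bincount`. It assumes every label already lies in `[0, num_classes)`. `confusion_matrix_from_labels` is a hypothetical name, and this is not necessarily how mmpretrain implements `calculate`. It reproduces the matrix from the basic-usage example, with rows indexing the ground truth and columns the prediction.

```python
import torch


def confusion_matrix_from_labels(pred_label, target, num_classes):
    """Hypothetical sketch: count (target, prediction) pairs into a (C, C) matrix.

    Rows index the ground-truth class; columns index the predicted class.
    Assumes all labels are integers in [0, num_classes).
    """
    pred_label = torch.as_tensor(pred_label).long().flatten()
    target = torch.as_tensor(target).long().flatten()
    # Encode each (target, pred) pair as one flat index: target * C + pred.
    indices = target * num_classes + pred_label
    # Count every flat index; minlength guarantees exactly C * C bins.
    counts = torch.bincount(indices, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


matrix = confusion_matrix_from_labels([0, 1, 1, 3], [0, 2, 1, 3], num_classes=4)
# matrix.tolist() -> [[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
```

Encoding each pair as `target * num_classes + pred` reduces the 2-D count to a single `bincount` call. A label outside the assumed range would create extra bins and make the `reshape` fail, which is one reason an explicit class count is useful when `pred` contains labels.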

static plot(confusion_matrix, include_values=False, cmap='viridis', classes=None, colorbar=True, show=True)[source]

Draw a confusion matrix with matplotlib.

Modified from scikit-learn.

Parameters:
  • confusion_matrix (torch.Tensor) – The confusion matrix to draw.

  • include_values (bool) – Whether to draw the values in the figure. Defaults to False.

  • cmap (str) – The color map to use. Defaults to “viridis”.

  • classes (list[str], optional) – The names of categories. Defaults to None, which uses the class indices as labels.

  • colorbar (bool) – Whether to show the colorbar. Defaults to True.

  • show (bool) – Whether to show the figure immediately. Defaults to True.
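For readers working without mmpretrain, the following is a rough plain-matplotlib counterpart to the parameters listed above. `plot_confusion_matrix` is a hypothetical helper. The axis titles and the rule for choosing the text colour are assumptions for this sketch, not mmpretrain's behaviour. Instead of a `show` flag, it returns the figure so that the caller decides whether to display or save it.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch


def plot_confusion_matrix(confusion_matrix, include_values=False, cmap='viridis',
                          classes=None, colorbar=True):
    """Hypothetical sketch of a confusion-matrix plot; returns the Figure."""
    cm = torch.as_tensor(confusion_matrix).cpu().numpy()
    num_classes = cm.shape[0]
    # Fall back to class indices when no names are given.
    if classes is None:
        labels = [str(i) for i in range(num_classes)]
    else:
        labels = [str(c) for c in classes]

    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)

    if include_values:
        # Write each count in its cell. White text on the lower half of the
        # value range assumes a dark-to-light colormap such as viridis.
        threshold = (cm.max() + cm.min()) / 2.0
        for i in range(num_classes):
            for j in range(num_classes):
                color = 'white' if cm[i, j] < threshold else 'black'
                ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color)

    if colorbar:
        fig.colorbar(image, ax=ax)

    ax.set_xticks(list(range(num_classes)))
    ax.set_yticks(list(range(num_classes)))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    fig.tight_layout()
    return fig


fig = plot_confusion_matrix(torch.tensor([[1, 0], [2, 3]]),
                            include_values=True, classes=['cat', 'dog'])
# With an interactive backend, call plt.show(); otherwise fig.savefig(...).
```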
