ConfusionMatrix
- class mmpretrain.evaluation.ConfusionMatrix(num_classes=None, collect_device='cpu', prefix=None)
A metric to calculate the confusion matrix for single-label tasks.
- Parameters:
num_classes (int, optional) – The number of classes. Defaults to None.
collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.
prefix (str, optional) – The prefix added to the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided, self.default_prefix is used instead. Defaults to None.
Examples
The basic usage.
>>> import torch
>>> from mmpretrain.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
>>> # plot the confusion matrix
>>> import matplotlib.pyplot as plt
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.randint(10, (1000, ))
>>> matrix = ConfusionMatrix.calculate(y_score, y_true)
>>> ConfusionMatrix().plot(matrix)
>>> plt.show()
In the config file:
val_evaluator = dict(type='ConfusionMatrix')
test_evaluator = dict(type='ConfusionMatrix')
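Since ConfusionMatrix is configured like any other metric, it can presumably sit next to other metrics in the same evaluator. The fragment below is a sketch that assumes MMEngine accepts a list of metric configs for `val_evaluator` and that mmpretrain's `Accuracy` metric takes a `topk` argument; check both against your installed versions.

```python
# Sketch: report top-1 accuracy and the confusion matrix in one evaluation run.
# Assumption: the evaluator accepts a list of metric configs.
val_evaluator = [
    dict(type='Accuracy', topk=(1, )),
    dict(type='ConfusionMatrix'),
]
# Reuse the same metrics for the test loop.
test_evaluator = val_evaluator
```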
- static calculate(pred, target, num_classes=None)
Calculate the confusion matrix for single-label tasks.
- Parameters:
pred (torch.Tensor | np.ndarray | Sequence) – The prediction results. It can be either labels with shape (N, ) or scores for every class with shape (N, C).
target (torch.Tensor | np.ndarray | Sequence) – The target of each prediction with shape (N, ).
num_classes (int, optional) – The number of classes. If pred contains labels instead of scores, this argument is required. Defaults to None.
- Returns:
The confusion matrix.
- Return type:
torch.Tensor
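The counting performed by calculate can be approximated in plain NumPy. This is a sketch, not mmpretrain's implementation: `confusion_matrix_sketch` is a hypothetical name, raising ValueError is an illustrative choice, and the layout (rows are ground-truth classes, columns are predicted classes) is inferred from the example output in the class docstring. Each (target, prediction) pair is encoded as one flat index and the indices are counted with `np.bincount`.

```python
import numpy as np


def confusion_matrix_sketch(pred, target, num_classes=None):
    """Hypothetical NumPy re-implementation of the counting step.

    Rows index the ground-truth class, columns the predicted class.
    """
    pred = np.asarray(pred)
    target = np.asarray(target, dtype=np.int64)
    assert target.ndim == 1 and pred.shape[0] == target.shape[0]

    if pred.ndim == 1:
        # Labels of shape (N, ): the class count cannot be inferred reliably.
        if num_classes is None:
            raise ValueError('num_classes is required when pred holds labels')
        pred_label = pred.astype(np.int64)
    else:
        # Scores of shape (N, C): infer C and take the arg-max class per sample.
        num_classes = num_classes or pred.shape[1]
        pred_label = pred.argmax(axis=1)

    # Encode each (target, pred) pair as a single flat index, then count.
    indices = num_classes * target + pred_label
    counts = np.bincount(indices, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


matrix = confusion_matrix_sketch([0, 1, 1, 3], [0, 2, 1, 3], num_classes=4)
print(matrix)
```

With the same inputs as the docstring example, this prints the same 4×4 counts as the tensor shown there. The flat-index trick turns a two-dimensional histogram into a single `bincount` call, which avoids a Python loop over samples.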
- static plot(confusion_matrix, include_values=False, cmap='viridis', classes=None, colorbar=True, show=True)
Draw a confusion matrix with matplotlib.
Modified from Scikit-Learn.
- Parameters:
confusion_matrix (torch.Tensor) – The confusion matrix to draw.
include_values (bool) – Whether to draw the values in the figure. Defaults to False.
cmap (str) – The color map to use. Defaults to “viridis”.
classes (list[str], optional) – The names of categories. Defaults to None, which means class indices are used as tick labels.
colorbar (bool) – Whether to show the colorbar. Defaults to True.
show (bool) – Whether to show the figure immediately. Defaults to True.
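The documented parameters map onto a handful of matplotlib calls. The helper below is a hypothetical sketch, not the code used by mmpretrain or Scikit-Learn: `plot_confusion_matrix_sketch` is an illustrative name, it assumes rows are ground-truth classes and columns are predictions, and it returns the figure instead of taking a `show` argument so it can run on a headless machine.

```python
import matplotlib

matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix_sketch(confusion_matrix, include_values=False,
                                 cmap='viridis', classes=None, colorbar=True):
    """Hypothetical helper mirroring the documented ``plot`` parameters."""
    cm = np.asarray(confusion_matrix)
    num_classes = cm.shape[0]
    # Fall back to index numbers when no class names are given.
    if classes is None:
        classes = [str(i) for i in range(num_classes)]

    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    if colorbar:
        fig.colorbar(image, ax=ax)  # adds a second Axes to the figure
    if include_values:
        # Write each count at the center of its cell (x = column, y = row).
        for i in range(num_classes):
            for j in range(num_classes):
                ax.text(j, i, str(cm[i, j]), ha='center', va='center')

    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    return fig


fig = plot_confusion_matrix_sketch([[5, 1], [2, 7]], include_values=True,
                                   classes=['cat', 'dog'])
```

Returning the figure leaves the caller free to call `fig.savefig(...)` or, with an interactive backend, `plt.show()`.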