# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import mmcv
import numpy as np
from mmengine.dist import master_only
from mmengine.visualization import Visualizer

from mmcls.registry import VISUALIZERS
from mmcls.structures import ClsDataSample

def _get_adaptive_scale(img_shape: Tuple[int, int],
                        min_scale: float = 0.3,
                        max_scale: float = 3.0) -> float:
    """Get adaptive scale according to image shape.

    The target scale depends on the the short edge length of the image. If the
    short edge length equals 224, the output is 1.0. And output linear scales
    according the short edge length.

    You can also specify the minimum scale and the maximum scale to limit the
    linear scale.

        img_shape (Tuple[int, int]): The shape of the canvas image.
        min_size (int): The minimum scale. Defaults to 0.3.
        max_size (int): The maximum scale. Defaults to 3.0.

        int: The adaptive scale.
    short_edge_length = min(img_shape)
    scale = short_edge_length / 224.
    return min(max(scale, min_scale), max_scale)

@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
    """Universal Visualizer for classification task.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to empty dict.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to empty dict.

    Examples:
        >>> import torch
        >>> import mmcv
        >>> from pathlib import Path
        >>> from mmcls.visualization import ClsVisualizer
        >>> from mmcls.structures import ClsDataSample
        >>> # Example image
        >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
        >>> # Example annotation
        >>> data_sample = (ClsDataSample().set_gt_label(1).set_pred_label(1)
        ...                .set_pred_score(torch.tensor([0.1, 0.8, 0.1])))
        >>> # Setup the visualizer
        >>> vis = ClsVisualizer(
        ...     save_dir="./outputs",
        ...     vis_backends=[dict(type='LocalVisBackend')])
        >>> # Set classes names
        >>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
        >>> # Show the example image with annotation in a figure.
        >>> # And it will ignore all preset storage backends.
        >>> vis.add_datasample('res', img, data_sample, show=True)
        >>> # Save the visualization result by the specified storage backends.
        >>> vis.add_datasample('res', img, data_sample)
        >>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
        >>> # Save another visualization result with the same name.
        >>> vis.add_datasample('res', img, data_sample, step=1)
        >>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
    """

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional[ClsDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       draw_score: bool = True,
                       rescale_factor: Optional[float] = None,
                       show: bool = False,
                       text_cfg: dict = dict(),
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If ``out_file`` is specified, all storage backends are ignored
          and save the image to the ``out_file``.
        - If ``show`` is True, plot the result image in a window, please
          confirm you are able to access the graphical interface.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`ClsDataSample`, optional): The annotation of
                the image. If None, no label text is drawn.
                Defaults to None.
            draw_gt (bool): Whether to draw ground truth labels.
                Defaults to True.
            draw_pred (bool): Whether to draw prediction labels.
                Defaults to True.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            rescale_factor (float, optional): Rescale the image by the rescale
                factor before visualization. Defaults to None.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            text_cfg (dict): Extra text setting, which accepts
                arguments of :attr:`mmengine.Visualizer.draw_texts`.
                Defaults to an empty dict.
            wait_time (float): The interval of show (s). Defaults to 0, which
                means "forever".
            out_file (str, optional): Extra path to save the visualization
                result. If specified, the visualizer will only save the result
                image to the out_file and ignore its storage backends.
                Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        classes = None
        if self.dataset_meta is not None:
            classes = self.dataset_meta.get('classes', None)

        if rescale_factor is not None:
            image = mmcv.imrescale(image, rescale_factor)

        texts = []
        self.set_image(image)

        # ``data_sample`` is optional; without it there is nothing to label,
        # and a membership test against None would raise TypeError.
        has_sample = data_sample is not None

        if draw_gt and has_sample and 'gt_label' in data_sample:
            gt_label = data_sample.gt_label
            idx = gt_label.label.tolist()
            class_labels = [''] * len(idx)
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(label_id) + class_name
                for label_id, class_name in zip(idx, class_labels)
            ]
            prefix = 'Ground truth: '
            # Continuation lines are padded so labels align under the first.
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        if draw_pred and has_sample and 'pred_label' in data_sample:
            pred_label = data_sample.pred_label
            idx = pred_label.label.tolist()
            score_labels = [''] * len(idx)
            class_labels = [''] * len(idx)
            if draw_score and 'score' in pred_label:
                # The score vector is indexed by predicted category id.
                score_labels = [
                    f', {pred_label.score[i].item():.2f}' for i in idx
                ]

            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]

            labels = [
                str(label_id) + score + class_name for label_id, score,
                class_name in zip(idx, score_labels, class_labels)
            ]
            prefix = 'Prediction: '
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        img_scale = _get_adaptive_scale(image.shape[:2])
        # Build a fresh dict so the (shared) default ``text_cfg`` argument is
        # never mutated; user-supplied keys override the defaults below.
        text_cfg = {
            'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
            'font_sizes': int(img_scale * 7),
            'font_families': 'monospace',
            'colors': 'white',
            'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
            **text_cfg
        }
        self.draw_texts('\n'.join(texts), **text_cfg)
        drawn_img = self.get_image()

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # save the image to the target file instead of vis_backends
            # (channel order is reversed for mmcv.imwrite)
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step=step)
