Shortcuts

mmcls.structures.cls_data_sample 源代码

# Copyright (c) OpenMMLab. All rights reserved.

from multiprocessing.reduction import ForkingPickler
from numbers import Integral, Number
from typing import Sequence, Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement, LabelData
from mmengine.utils import is_str


def format_label(
        value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
    """Convert various python types to label-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int`. Integer scalars registered as
    :class:`numbers.Integral` (e.g. ``numpy.int64``) are accepted like
    :class:`int`.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
            A 0-dim tensor/array is treated as a single label.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D label tensor. ``ndarray``,
        ``Sequence`` and integer inputs are converted to ``torch.long``;
        a 1-D ``torch.Tensor`` input is returned as-is (dtype unchanged).

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the formatted label is not 1-dimensional.
    """

    # A 0-dim tensor/array holds exactly one label: unwrap it to a Python int
    # so the integer branch below turns it into a 1-element tensor.
    if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
        value = int(value.item())

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).to(torch.long)
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).to(torch.long)
    elif isinstance(value, Integral):
        # ``Integral`` covers ``int`` as well as NumPy integer scalars such as
        # ``np.int64``, which are not ``np.ndarray`` instances and would
        # otherwise be rejected below.
        value = torch.LongTensor([int(value)])
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available label type.')
    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'

    return value


def format_score(
        value: Union[torch.Tensor, np.ndarray, Sequence]) -> torch.Tensor:
    """Convert various python types to score-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence): Score values.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D score tensor. ``ndarray`` and
        ``Sequence`` inputs are converted with ``.float()``; a 1-D
        ``torch.Tensor`` input is returned as-is (dtype unchanged).

    Raises:
        TypeError: If ``value`` is not one of the supported types (plain
            numbers are not accepted as scores).
        AssertionError: If the formatted score is not 1-dimensional.
    """

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).float()
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).float()
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available score type.')
    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'

    return value


class ClsDataSample(BaseDataElement):
    """A data structure interface of classification task.

    It's used as interfaces between different components.

    Meta fields:
        img_shape (Tuple): The shape of the corresponding input image.
            Used for visualization.
        ori_shape (Tuple): The original shape of the corresponding image.
            Used for visualization.
        num_classes (int): The number of all categories.
            Used for label format conversion.

    Data fields:
        gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth
            label.
        pred_label (:obj:`~mmengine.structures.LabelData`): The predicted
            label.
        scores (torch.Tensor): The outputs of model.
        logits (torch.Tensor): The outputs of model without softmax nor
            sigmoid.

    Examples:
        >>> import torch
        >>> from mmcls.structures import ClsDataSample
        >>>
        >>> img_meta = dict(img_shape=(960, 720), num_classes=5)
        >>> data_sample = ClsDataSample(metainfo=img_meta)
        >>> data_sample.set_gt_label(3)
        >>> print(data_sample)
        <ClsDataSample(
            META INFORMATION
            num_classes = 5
            img_shape = (960, 720)
            DATA FIELDS
            gt_label: <LabelData(
                    META INFORMATION
                    DATA FIELDS
                    label: tensor([3])
                ) at 0x7f21fb1b9190>
        ) at 0x7f21fb1b9880>
        >>> # For multi-label data
        >>> data_sample.set_gt_label([0, 1, 4])
        >>> print(data_sample.gt_label)
        <LabelData(
            META INFORMATION
            DATA FIELDS
            label: tensor([0, 1, 4])
        ) at 0x7fd7d1b41970>
        >>> # Set the per-class score vector
        >>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
        >>> data_sample.set_pred_score(score)
        >>> print(data_sample.pred_label)
        <LabelData(
            META INFORMATION
            DATA FIELDS
            score: tensor([0.1000, 0.1000, 0.6000, 0.1000, 0.1000])
        ) at 0x7fd7d1b41970>
    """

    def set_gt_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ClsDataSample':
        """Set label of ``gt_label``."""
        label_data = getattr(self, '_gt_label', LabelData())
        label_data.label = format_label(value)
        self.gt_label = label_data
        return self

    def set_gt_score(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number]]
    ) -> 'ClsDataSample':
        """Set score of ``gt_label``.

        The score length is checked against ``num_classes`` (or used to set
        it) before the stored ``gt_label`` is touched, so a rejected score
        leaves the existing label data unchanged.
        """
        score = format_score(value)
        self._check_score_num_classes(score)
        label_data = getattr(self, '_gt_label', LabelData())
        label_data.score = score
        self.gt_label = label_data
        return self

    def set_pred_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ClsDataSample':
        """Set label of ``pred_label``."""
        label_data = getattr(self, '_pred_label', LabelData())
        label_data.label = format_label(value)
        self.pred_label = label_data
        return self

    def set_pred_score(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number]]
    ) -> 'ClsDataSample':
        """Set score of ``pred_label``.

        The score length is checked against ``num_classes`` (or used to set
        it) before the stored ``pred_label`` is touched, so a rejected score
        leaves the existing label data unchanged.
        """
        score = format_score(value)
        self._check_score_num_classes(score)
        label_data = getattr(self, '_pred_label', LabelData())
        label_data.score = score
        self.pred_label = label_data
        return self

    def _check_score_num_classes(self, score: torch.Tensor) -> None:
        """Validate ``score`` against ``num_classes``, or record it.

        If the ``num_classes`` metainfo already exists, the score length must
        equal it; otherwise ``num_classes`` is set from the score length.
        """
        if hasattr(self, 'num_classes'):
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')

    @property
    def gt_label(self):
        return self._gt_label

    @gt_label.setter
    def gt_label(self, value: LabelData):
        self.set_field(value, '_gt_label', dtype=LabelData)

    @gt_label.deleter
    def gt_label(self):
        del self._gt_label

    @property
    def pred_label(self):
        return self._pred_label

    @pred_label.setter
    def pred_label(self, value: LabelData):
        self.set_field(value, '_pred_label', dtype=LabelData)

    @pred_label.deleter
    def pred_label(self):
        del self._pred_label
def _reduce_cls_datasample(data_sample):
    """Reduce ``ClsDataSample`` for ``ForkingPickler``.

    Every :obj:`LabelData` attribute is replaced, in a copy of the sample's
    attribute dict, by the result of its ``.numpy()`` call, so the pickled
    payload carries numpy arrays instead of tensors.

    Args:
        data_sample (ClsDataSample): The sample being pickled. It is not
            modified.

    Returns:
        tuple: ``(_rebuild_cls_datasample, (attr_dict, convert_keys))``, where
        ``convert_keys`` lists the attributes that were converted to numpy.
    """
    # Shallow-copy the attribute dict: assigning into ``data_sample.__dict__``
    # directly would replace the sender's own LabelData with numpy versions
    # as a side effect of pickling.
    attr_dict = data_sample.__dict__.copy()
    convert_keys = []
    for k, v in attr_dict.items():
        if isinstance(v, LabelData):
            attr_dict[k] = v.numpy()
            convert_keys.append(k)
    return _rebuild_cls_datasample, (attr_dict, convert_keys)


def _rebuild_cls_datasample(attr_dict, convert_keys):
    """Rebuild ``ClsDataSample`` from the output of the reducer.

    Args:
        attr_dict (dict): Attribute dict produced by
            ``_reduce_cls_datasample``.
        convert_keys (list[str]): Attributes to convert back with
            ``.to_tensor()``.

    Returns:
        ClsDataSample: The reconstructed data sample.
    """
    data_sample = ClsDataSample()
    for k in convert_keys:
        attr_dict[k] = attr_dict[k].to_tensor()
    data_sample.__dict__ = attr_dict
    return data_sample


# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
# many file descriptors because it contains multiple LabelData with tensors.
# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
# convert these tensors to np.ndarray during pickling. It may influence the
# performance of dataloader, but slightly because these tensors in LabelData
# are very small.
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)
Read the Docs v: mmcls-1.x
Versions
latest
stable
mmcls-1.x
mmcls-0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.