
mmcls.structures

This package contains the basic data structures used for classification tasks.

ClsDataSample

class mmcls.structures.ClsDataSample(*, metainfo=None, **kwargs)[source]

A data structure interface for classification tasks.

It is used as the interface between different components.

Meta fields
  • img_shape (Tuple) – The shape of the corresponding input image. Used for visualization.

  • ori_shape (Tuple) – The original shape of the corresponding image. Used for visualization.

  • num_classes (int) – The number of all categories. Used for label format conversion.
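The label format conversion that num_classes supports can be sketched in plain Python. This is only an illustration of the idea, not the mmcls implementation: the helper names indices_to_onehot and onehot_to_indices are hypothetical and do not come from the library.

```python
from typing import Iterable, List


def indices_to_onehot(labels: Iterable[int], num_classes: int) -> List[int]:
    """Turn category indices into a 0/1 vector of length num_classes.

    Hypothetical helper for illustration only; not part of mmcls.
    """
    onehot = [0] * num_classes
    for idx in labels:
        if not 0 <= idx < num_classes:
            raise ValueError(f"label {idx} is out of range for {num_classes} classes")
        onehot[idx] = 1
    return onehot


def onehot_to_indices(onehot: Iterable[int]) -> List[int]:
    """Inverse conversion: collect the positions that are set."""
    return [i for i, flag in enumerate(onehot) if flag]


# Single-label case: one category index out of five classes.
single = indices_to_onehot([3], num_classes=5)
# Multi-label case: several category indices at once.
multi = indices_to_onehot([0, 1, 4], num_classes=5)

print(single)  # [0, 0, 0, 1, 0]
print(multi)   # [1, 1, 0, 0, 1]
```

Without num_classes the length of the vector would be unknown, which is why the index format alone is not enough for this conversion.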

Data fields
  • gt_label (LabelData) – The ground truth label.

  • pred_label (LabelData) – The predicted label.

  • scores (torch.Tensor) – The outputs of the model.

  • logits (torch.Tensor) – The raw outputs of the model, before softmax or sigmoid is applied.
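The difference between logits and scores can be illustrated with a short plain-Python sketch. It assumes that scores are obtained by applying softmax (single-label) or sigmoid (multi-label) to the logits; the page above does not state which one a given model uses, and the functions below are illustrative stand-ins rather than mmcls code.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Normalize logits into scores that sum to 1 (single-label assumption)."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits: Sequence[float]) -> List[float]:
    """Map each logit independently into (0, 1) (multi-label assumption)."""
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


# Logits chosen so that softmax yields roughly [0.1, 0.1, 0.6, 0.1, 0.1].
logits = [0.0, 0.0, math.log(6.0), 0.0, 0.0]
scores = softmax(logits)
print([round(s, 4) for s in scores])

# With sigmoid, each class is scored on its own and the values need not sum to 1.
independent = sigmoid(logits)
print([round(s, 4) for s in independent])
```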

Examples

>>> import torch
>>> from mmcls.structures import ClsDataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = ClsDataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<ClsDataSample(
   META INFORMATION
   num_classes = 5
   img_shape = (960, 720)
   DATA FIELDS
   gt_label: <LabelData(
           META INFORMATION
           num_classes: 5
           DATA FIELDS
           label: tensor([3])
       ) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # For multi-label data
>>> data_sample.set_gt_label([0, 1, 4])
>>> print(data_sample.gt_label)
<LabelData(
    META INFORMATION
    num_classes: 5
    DATA FIELDS
    label: tensor([0, 1, 4])
) at 0x7fd7d1b41970>
>>> # Set a score vector with one entry per class
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData(
    META INFORMATION
    num_classes: 5
    DATA FIELDS
    score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970>
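As a follow-up to the score set in the example above, the sketch below shows one way a predicted label could be read off such a score vector. It uses plain Python lists instead of tensors, and the helper names and the 0.5 threshold are assumptions made for illustration, not behavior documented for ClsDataSample.

```python
from typing import List, Sequence


def predicted_index(score: Sequence[float]) -> int:
    """Single-label prediction: the position of the highest score."""
    return max(range(len(score)), key=score.__getitem__)


def labels_above(score: Sequence[float], threshold: float) -> List[int]:
    """Multi-label prediction: every class whose score exceeds the threshold."""
    return [i for i, s in enumerate(score) if s > threshold]


# Same values as the score vector in the example above.
score = [0.1, 0.1, 0.6, 0.1, 0.1]

print(predicted_index(score))    # 2
print(labels_above(score, 0.5))  # [2]
```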