mmpretrain.structures
This package includes basic data structures.
DataSample
class mmpretrain.structures.DataSample(*, metainfo=None, **kwargs)
A general data structure interface.
It’s used as the interface between different components.
The following fields are conventional names in MMPretrain; data transforms, models, and metrics set or read these fields as needed. You can also set any new fields you need.
- Meta fields:
img_shape (Tuple) – The shape of the corresponding input image.
ori_shape (Tuple) – The original shape of the corresponding image.
sample_idx (int) – The index of the sample in the dataset.
num_classes (int) – The total number of categories.
- Data fields:
gt_label (tensor) – The ground truth label.
gt_score (tensor) – The ground truth score.
pred_label (tensor) – The predicted label.
pred_score (tensor) – The predicted score.
mask (tensor) – The mask used in masked image modeling.
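The label and score fields describe classes in two forms: `gt_label` and `pred_label` hold class indices, while `gt_score` and `pred_score` hold one value per category. The minimal sketch below illustrates that relationship in plain Python, assuming lists instead of tensors; `labels_to_multi_hot` and `top1_label` are hypothetical helpers written for illustration, not MMPretrain functions.

```python
from typing import List, Sequence


def labels_to_multi_hot(labels: Sequence[int], num_classes: int) -> List[float]:
    """Hypothetical helper: turn class indices (gt_label style) into a
    per-class vector (gt_score style) with 1.0 at every labelled index."""
    vector = [0.0] * num_classes
    for idx in labels:
        if not 0 <= idx < num_classes:
            raise ValueError(f"label {idx} out of range for {num_classes} classes")
        vector[idx] = 1.0
    return vector


def top1_label(scores: Sequence[float]) -> int:
    """Hypothetical helper: return the index of the highest score, i.e. a
    single-label pred_label derived from a pred_score vector."""
    if not scores:
        raise ValueError("scores must be non-empty")
    # max() keeps the first index on ties.
    return max(range(len(scores)), key=lambda i: scores[i])


# Multi-label ground truth [0, 1, 4] in a 5-class problem.
print(labels_to_multi_hot([0, 1, 4], num_classes=5))  # [1.0, 1.0, 0.0, 0.0, 1.0]
# A 4-class score vector; index 2 has the highest score.
print(top1_label([0.1, 0.1, 0.6, 0.1]))  # 2
```

Taking the arg-max is only one way to turn scores into a label; multi-label setups usually threshold each score instead.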
Examples
>>> import torch
>>> from mmpretrain.structures import DataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = DataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<DataSample(
    META INFORMATION
    num_classes: 5
    img_shape: (960, 720)
    DATA FIELDS
    gt_label: tensor([3])
) at 0x7ff64c1c1d30>
>>>
>>> # For multi-label data
>>> data_sample = DataSample().set_gt_label([0, 1, 4])
>>> print(data_sample)
<DataSample(
    DATA FIELDS
    gt_label: tensor([0, 1, 4])
) at 0x7ff5b490e100>
>>>
>>> # Set per-class scores (num_classes is inferred from their length)
>>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
>>> print(data_sample)
<DataSample(
    META INFORMATION
    num_classes: 4
    DATA FIELDS
    pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
) at 0x7ff5b48ef6a0>
>>>
>>> # Set custom field
>>> data_sample = DataSample()
>>> data_sample.my_field = [1, 2, 3]
>>> print(data_sample)
<DataSample(
    DATA FIELDS
    my_field: [1, 2, 3]
) at 0x7f8e9603d3a0>
>>> print(data_sample.my_field)
[1, 2, 3]
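The printed output separates META INFORMATION from DATA FIELDS. The sketch below is a minimal, self-contained illustration of such a container, not MMPretrain's implementation (DataSample builds on MMEngine's `BaseDataElement`). `SketchDataSample` is a hypothetical class: it stores labels and scores as plain Python lists instead of tensors and only mimics the behaviour visible in the examples, namely that an integer label becomes a one-element list, a score vector sets `num_classes` when it is not already known, and any other attribute becomes a data field.

```python
from typing import Any, Dict, Optional, Sequence, Union


class SketchDataSample:
    """Hypothetical DataSample-like container (illustration only).

    Meta fields live in ``_metainfo``; everything else lives in ``_data``.
    """

    def __init__(self, metainfo: Optional[Dict[str, Any]] = None) -> None:
        # Bypass our own __setattr__ while creating the two stores.
        object.__setattr__(self, "_metainfo", dict(metainfo or {}))
        object.__setattr__(self, "_data", {})

    def __setattr__(self, name: str, value: Any) -> None:
        # Any attribute set by the user (e.g. my_field) becomes a data field.
        self._data[name] = value

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails: search data, then metainfo.
        data = self.__dict__.get("_data", {})
        meta = self.__dict__.get("_metainfo", {})
        if name in data:
            return data[name]
        if name in meta:
            return meta[name]
        raise AttributeError(name)

    @property
    def metainfo(self) -> Dict[str, Any]:
        return dict(self._metainfo)

    @property
    def data_fields(self) -> Dict[str, Any]:
        return dict(self._data)

    def set_gt_label(self, value: Union[int, Sequence[int]]) -> "SketchDataSample":
        # An int becomes a one-element list, like tensor([3]) in the examples.
        label = [value] if isinstance(value, int) else list(value)
        self._data["gt_label"] = label
        return self  # returning self allows chained calls

    def set_pred_score(self, value: Sequence[float]) -> "SketchDataSample":
        score = [float(v) for v in value]
        known = self._metainfo.get("num_classes")
        if known is not None and known != len(score):
            raise ValueError(f"score length {len(score)} != num_classes {known}")
        # Infer num_classes from the score length when it is not yet known.
        self._metainfo.setdefault("num_classes", len(score))
        self._data["pred_score"] = score
        return self


sample = SketchDataSample(metainfo=dict(img_shape=(960, 720), num_classes=5))
sample.set_gt_label(3)
print(sample.gt_label)  # [3]

scored = SketchDataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
print(scored.metainfo)  # {'num_classes': 4}

custom = SketchDataSample()
custom.my_field = [1, 2, 3]
print(custom.data_fields)  # {'my_field': [1, 2, 3]}
```

Keeping metadata and data in separate stores is what lets a container print them as two sections and treat per-sample data differently from descriptive metadata.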