MultiLabelClsHead
- class mmpretrain.models.heads.MultiLabelClsHead(loss={'type': 'CrossEntropyLoss', 'use_sigmoid': True}, thr=None, topk=None, init_cfg=None)
Classification head for multi-label tasks.
- Parameters:
loss (dict) – Config of the classification loss. Defaults to dict(type='CrossEntropyLoss', use_sigmoid=True).
thr (float, optional) – Predictions with scores below the threshold are considered negative. Defaults to None.
topk (int, optional) – The classes with the k highest scores are considered positive. Defaults to None.
init_cfg (dict, optional) – The extra init config of layers. Defaults to None.
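How thr and topk interact, including the fallback described in the Notes, can be sketched in plain Python. This is a minimal illustration, not the mmpretrain implementation: it assumes the per-class scores are already probabilities in [0, 1], uses lists instead of tensors, and the helper names resolve_thr_topk and select_positive are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def resolve_thr_topk(thr: Optional[float] = None,
                     topk: Optional[int] = None) -> Tuple[Optional[float], Optional[int]]:
    """Fall back to thr=0.5 when neither thr nor topk is given."""
    if thr is None and topk is None:
        thr = 0.5
    return thr, topk


def select_positive(scores: Sequence[float],
                    thr: Optional[float] = None,
                    topk: Optional[int] = None) -> List[int]:
    """Return the indices of the classes predicted as positive (hypothetical helper)."""
    thr, topk = resolve_thr_topk(thr, topk)
    if thr is not None:
        # thr wins over topk: scores below the threshold are negative.
        return [i for i, s in enumerate(scores) if s >= thr]
    # Otherwise keep the k highest-scoring classes, best first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:topk]


scores = [0.9, 0.2, 0.6, 0.4]
print(select_positive(scores))                   # [0, 2]  (default thr=0.5)
print(select_positive(scores, topk=1))           # [0]
print(select_positive(scores, thr=0.3, topk=1))  # [0, 2, 3]  (thr takes precedence)
```

Treating a score exactly equal to thr as positive follows from the wording "scores below the threshold are considered negative".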
Notes
If both thr and topk are set, thr is used to determine positive predictions. If neither is set, thr=0.5 is used by default.
- loss(feats, data_samples, **kwargs)
Calculate losses from the classification score.
- Parameters:
feats (tuple[Tensor]) – The features extracted from the backbone. Multiple stage inputs are acceptable, but only the last stage is used for classification. The shape of every item should be (num_samples, num_classes).
data_samples (List[DataSample]) – The annotation data of every sample.
**kwargs – Other keyword arguments passed to the loss module.
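With use_sigmoid=True, the loss treats each class as an independent binary decision, i.e. a sigmoid binary cross-entropy on the raw scores. The sketch below illustrates that computation in plain Python under stated assumptions: ground-truth labels are given as lists of positive class indices and converted to multi-hot targets, the loss is averaged over all (sample, class) entries, and the helper names are hypothetical. The reduction mmpretrain actually applies may differ.

```python
import math
from typing import Dict, List, Sequence


def to_multi_hot(labels: Sequence[int], num_classes: int) -> List[float]:
    """Convert positive class indices into a 0/1 target vector."""
    target = [0.0] * num_classes
    for idx in labels:
        target[idx] = 1.0
    return target


def sigmoid_bce(logit: float, target: float) -> float:
    """Binary cross-entropy of sigmoid(logit) against target, computed stably.

    max(x, 0) - x * t + log(1 + exp(-|x|)) equals
    -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))] without overflow.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_loss(cls_scores: Sequence[Sequence[float]],
                    gt_labels: Sequence[Sequence[int]]) -> Dict[str, float]:
    """Average sigmoid BCE over every (sample, class) entry (assumed reduction)."""
    num_classes = len(cls_scores[0])
    total, count = 0.0, 0
    for logits, labels in zip(cls_scores, gt_labels):
        target = to_multi_hot(labels, num_classes)
        total += sum(sigmoid_bce(x, t) for x, t in zip(logits, target))
        count += num_classes
    return {'loss': total / count}


# Two samples, three classes; sample 0 has classes 0 and 2, sample 1 has class 1.
print(multilabel_loss([[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]], [[0, 2], [1]]))
```

Returning a dict keyed by loss name mirrors the documented return value, a dictionary of loss components.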
- Returns:
A dictionary of loss components.
- Return type:
dict[str, Tensor]
- pre_logits(feats)
The process before the final classification head.
The input feats is a tuple of tensors, and each tensor is the feature of a backbone stage. In MultiLabelClsHead, only the feature of the last stage is used.
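Selecting the last stage amounts to indexing the tuple with -1. A minimal sketch, using nested lists as hypothetical stand-ins for the per-stage tensors:

```python
from typing import Sequence, TypeVar

T = TypeVar('T')


def pre_logits(feats: Sequence[T]) -> T:
    """Return the feature of the last backbone stage; earlier stages are ignored."""
    return feats[-1]


# Nested lists stand in for the per-stage tensors of a batch with one sample.
stage_feats = ([[0.1, 0.2]], [[0.3, 0.4]], [[1.5, -0.7]])
print(pre_logits(stage_feats))  # [[1.5, -0.7]]
```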
- predict(feats, data_samples=None)
Inference without augmentation.
- Parameters:
feats (tuple[Tensor]) – The features extracted from the backbone. Multiple stage inputs are acceptable, but only the last stage is used for classification. The shape of every item should be (num_samples, num_classes).
data_samples (List[DataSample], optional) – The annotation data of every sample. If not None, the pred_label of these input data samples is set. Defaults to None.
- Returns:
A list of data samples which contain the predicted results.
- Return type:
List[DataSample]
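Putting the documented steps together, inference can be sketched as: take the last stage, turn the scores into per-class probabilities, pick positive classes with thr or topk, and attach the results to the data samples. This is a hedged plain-Python illustration, not the library code: SampleStub is a hypothetical stand-in for DataSample, labels are lists rather than tensors, storing pred_score is an assumption, and the sigmoid step is an assumption consistent with the default use_sigmoid=True loss.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional, Sequence


@dataclass
class SampleStub:
    """Hypothetical stand-in for DataSample, holding only prediction fields."""
    pred_score: List[float] = field(default_factory=list)
    pred_label: List[int] = field(default_factory=list)


def sigmoid(x: float) -> float:
    """Logistic function written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict(feats: Sequence[Sequence[Sequence[float]]],
            data_samples: Optional[List[SampleStub]] = None,
            thr: Optional[float] = None,
            topk: Optional[int] = None) -> List[SampleStub]:
    cls_score = feats[-1]  # only the last stage is used
    if thr is None and topk is None:
        thr = 0.5  # documented fallback
    if data_samples is None:
        # No annotations given: create fresh samples to hold the predictions.
        data_samples = [SampleStub() for _ in cls_score]
    for sample, logits in zip(data_samples, cls_score):
        scores = [sigmoid(x) for x in logits]
        if thr is not None:
            labels = [i for i, s in enumerate(scores) if s >= thr]
        else:
            ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
            labels = ranked[:topk]
        sample.pred_score = scores
        sample.pred_label = labels
    return data_samples


results = predict(([[0.0, 0.0]], [[2.0, -2.0, 0.1]]))
print(results[0].pred_label)  # [0, 2]
```

When data_samples is passed in, the same list is filled in and returned, matching the description of the data_samples parameter.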