TimmClassifier
- class mmpretrain.models.classifiers.TimmClassifier(*args, loss={'loss_weight': 1.0, 'type': 'CrossEntropyLoss'}, train_cfg=None, with_cp=False, data_preprocessor=None, init_cfg=None, **kwargs)
Image classifier for models from pytorch-image-models (timm).
This class accepts all positional and keyword arguments of the function timm.models.create_model and uses them to create a model from pytorch-image-models.
It can load timm checkpoints directly, and the checkpoints it saves can also be loaded directly by timm.
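The checkpoint interoperability described above mostly comes down to parameter-name prefixes. The sketch below assumes that the wrapper keeps the timm network in an attribute named `model`, so its state-dict keys carry a `model.` prefix that a plain timm model does not have. The helper names `timm_to_wrapper` and `wrapper_to_timm` are hypothetical and are not part of the mmpretrain API.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

# Assumption: the wrapper stores the timm network as `self.model`,
# so its parameter names look like "model.<timm key>".
PREFIX = "model."


def timm_to_wrapper(state: Dict[str, V]) -> Dict[str, V]:
    """Hypothetical helper: prepend the prefix to every timm key."""
    return {PREFIX + key: value for key, value in state.items()}


def wrapper_to_timm(state: Dict[str, V]) -> Dict[str, V]:
    """Hypothetical helper: strip the prefix so timm can load the keys."""
    return {
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state.items()
    }


timm_state = {"conv1.weight": "w0", "fc.bias": "b0"}
wrapped = timm_to_wrapper(timm_state)
print(sorted(wrapped))  # ['model.conv1.weight', 'model.fc.bias']
assert wrapper_to_timm(wrapped) == timm_state
```

Because the two mappings are inverses, a checkpoint can be converted in either direction without losing any keys.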
Please make sure that timm is installed before using this class.
- Parameters:
*args – All positional arguments of the function timm.models.create_model.
loss (dict) – Config of classification loss. Defaults to dict(type='CrossEntropyLoss', loss_weight=1.0).
train_cfg (dict, optional) – The training setting. The acceptable fields are:
augments (List[dict]): The batch augmentation methods to use. More details can be found in mmpretrain.model.utils.augment.
Defaults to None.
with_cp (bool) – Whether to use gradient checkpointing. Checkpointing saves some memory at the cost of slower training. Defaults to False.
data_preprocessor (dict, optional) – The config for preprocessing input data. If it is None or does not specify a type, “ClsDataPreprocessor” is used. See ClsDataPreprocessor for more details. Defaults to None.
init_cfg (dict, optional) – The config to control the initialization. Defaults to None.
**kwargs – Other keyword arguments of the function timm.models.create_model.
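As a sketch of how these arguments fit together in a config, the dict below combines TimmClassifier's own options with keyword arguments forwarded to timm.models.create_model (`model_name`, `pretrained`, `num_classes`). The Mixup/CutMix settings and normalization values are illustrative assumptions, not recommended settings. `resolve_preprocessor` is a hypothetical helper that mirrors the documented default for `data_preprocessor`; it is not mmpretrain's own code.

```python
import copy
from typing import Optional

# Illustrative config; all values are assumptions for demonstration only.
model_cfg = dict(
    type='TimmClassifier',
    # Forwarded to timm.models.create_model via **kwargs.
    model_name='resnet50',
    pretrained=False,
    num_classes=10,
    # TimmClassifier's own arguments.
    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ]),
    with_cp=False,
    # No 'type' key: the documented default should apply.
    data_preprocessor=dict(
        num_classes=10,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True,
    ),
)


def resolve_preprocessor(cfg: Optional[dict]) -> dict:
    """Hypothetical helper: fall back to ClsDataPreprocessor as documented."""
    # Copy so the caller's config is not mutated.
    resolved = copy.deepcopy(cfg) if cfg is not None else {}
    resolved.setdefault('type', 'ClsDataPreprocessor')
    return resolved


print(resolve_preprocessor(model_cfg['data_preprocessor'])['type'])
# ClsDataPreprocessor
```

An explicit `type` in `data_preprocessor` is left untouched. Only a missing config or a missing `type` falls back to the default.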
Examples
>>> import torch
>>> from mmpretrain.models import build_classifier
>>> cfg = dict(type='TimmClassifier', model_name='resnet50', pretrained=True)
>>> model = build_classifier(cfg)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])
- loss(inputs, data_samples, **kwargs)
Calculate losses from a batch of inputs and data samples.
- Parameters:
inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.
data_samples (List[DataSample]) – The annotation data of every sample.
**kwargs – Other keyword arguments of the loss module.
- Returns:
A dictionary of loss components.
- Return type:
dict[str, Tensor]
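To make the returned dictionary concrete, the pure-Python sketch below computes a batch-averaged softmax cross-entropy from raw class scores and integer labels, then scales it by `loss_weight`. It assumes the default loss config (`CrossEntropyLoss` with `loss_weight=1.0`) and a single output key named `loss`. The function name is hypothetical, and plain lists stand in for tensors and `DataSample` objects.

```python
import math
from typing import Dict, List


def cross_entropy_losses(scores: List[List[float]], labels: List[int],
                         loss_weight: float = 1.0) -> Dict[str, float]:
    """Hypothetical sketch: mean softmax cross-entropy scaled by loss_weight."""
    total = 0.0
    for row, label in zip(scores, labels):
        # Log-sum-exp with the max subtracted for numerical stability.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_norm - row[label]  # equals -log(softmax(row)[label])
    # Assumption: the single loss component is reported under 'loss'.
    return {'loss': loss_weight * total / len(labels)}


# Two samples, three classes; uniform scores give a loss of about log(3).
print(cross_entropy_losses([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [0, 2]))
```

Averaging over the batch before applying `loss_weight` keeps the loss scale independent of batch size.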
- predict(inputs, data_samples=None)
Predict results from a batch of inputs.
- Parameters:
inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.
data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.
- Returns:
The prediction results.
- Return type:
List[DataSample]
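The sketch below shows the kind of post-processing `predict` performs: a softmax over each sample's class scores, followed by an argmax to pick the label. Plain dicts with `pred_score` and `pred_label` keys stand in for `DataSample` objects. That mapping and the function name `predict_from_scores` are assumptions for illustration.

```python
import math
from typing import Dict, List, Union


def predict_from_scores(
        scores: List[List[float]]) -> List[Dict[str, Union[List[float], int]]]:
    """Hypothetical sketch: softmax + argmax for each sample in a batch."""
    results = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]  # shift by max for stability
        z = sum(exps)
        probs = [e / z for e in exps]
        # Index of the highest probability is the predicted label.
        label = max(range(len(probs)), key=probs.__getitem__)
        results.append({'pred_score': probs, 'pred_label': label})
    return results


preds = predict_from_scores([[2.0, 0.5, -1.0], [0.1, 0.3, 3.0]])
print([p['pred_label'] for p in preds])  # [0, 2]
```

Softmax is monotonic, so the argmax of the probabilities matches the argmax of the raw scores. The probabilities are still useful as `pred_score`.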