
TimmClassifier

class mmpretrain.models.classifiers.TimmClassifier(*args, loss={'loss_weight': 1.0, 'type': 'CrossEntropyLoss'}, train_cfg=None, with_cp=False, data_preprocessor=None, init_cfg=None, **kwargs)[source]

Image classifier that wraps a model from pytorch-image-models (timm).

This class accepts all positional and keyword arguments of the function timm.models.create_model and uses them to create a model from pytorch-image-models.
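The argument forwarding described above relies on ordinary Python signature capture: the wrapper's own named options are bound by name, and everything left over in *args and **kwargs is handed to the model factory. The sketch below illustrates this with a hypothetical stand-in, fake_create_model, that only echoes its arguments (it is not timm's function), and a hypothetical TimmClassifierSketch class that mirrors the documented signature. It is not mmpretrain's implementation.

```python
from typing import Any, Dict, Optional, Tuple


def fake_create_model(*args: Any, **kwargs: Any) -> Tuple[tuple, Dict[str, Any]]:
    """Hypothetical stand-in for timm.models.create_model: echoes its arguments."""
    return args, kwargs


class TimmClassifierSketch:
    """Hypothetical wrapper mirroring the documented TimmClassifier signature."""

    def __init__(self,
                 *args: Any,
                 loss: Optional[Dict[str, Any]] = None,
                 train_cfg: Optional[Dict[str, Any]] = None,
                 with_cp: bool = False,
                 data_preprocessor: Optional[Dict[str, Any]] = None,
                 init_cfg: Optional[Dict[str, Any]] = None,
                 **kwargs: Any) -> None:
        # Wrapper options are captured by name and never reach the factory.
        if loss is None:
            loss = {'type': 'CrossEntropyLoss', 'loss_weight': 1.0}
        self.loss_cfg = dict(loss)
        self.train_cfg = train_cfg
        self.with_cp = with_cp
        self.init_cfg = init_cfg
        # Fall back to ClsDataPreprocessor when no type is given (copy, don't mutate input).
        self.data_preprocessor = dict(data_preprocessor or {})
        self.data_preprocessor.setdefault('type', 'ClsDataPreprocessor')
        # Everything else (model_name, pretrained, num_classes, ...) goes to the factory.
        self.forwarded = fake_create_model(*args, **kwargs)


clf = TimmClassifierSketch('resnet50', pretrained=False, num_classes=10, with_cp=True)
print(clf.forwarded)  # (('resnet50',), {'pretrained': False, 'num_classes': 10})
```

Because the split happens in the signature, any keyword the wrapper does not name explicitly is treated as a create_model argument.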

It can load timm checkpoints directly, and the checkpoints it saves can also be loaded directly by timm.
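One way such round-trip compatibility can work: if the wrapper keeps the timm network under an attribute named `model` (an assumption here), every wrapper parameter name carries a `model.` prefix, so the prefix is stripped when saving and added when loading a plain timm checkpoint. The minimal sketch below shows only the key remapping, with plain numbers standing in for tensors; the helper names are hypothetical.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping

# Assumption: the wrapper stores the timm network as `self.model`,
# so its parameter names look like "model.<timm key>".
PREFIX = 'model.'


def to_timm_keys(state_dict: Mapping[str, Any], prefix: str = PREFIX) -> Dict[str, Any]:
    """Drop the wrapper prefix so timm can load the checkpoint directly."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())


def to_wrapper_keys(state_dict: Mapping[str, Any], prefix: str = PREFIX) -> Dict[str, Any]:
    """Add the wrapper prefix so a plain timm checkpoint matches the wrapper."""
    return OrderedDict((prefix + key, value) for key, value in state_dict.items())


timm_ckpt = {'conv1.weight': 0.1, 'fc.bias': 0.2}
wrapped = to_wrapper_keys(timm_ckpt)
print(list(wrapped))  # ['model.conv1.weight', 'model.fc.bias']
```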

Make sure timm is installed before using this class.

Parameters:
  • *args – All positional arguments of the function timm.models.create_model.

  • loss (dict) – Config of classification loss. Defaults to dict(type='CrossEntropyLoss', loss_weight=1.0).

  • train_cfg (dict, optional) –

    The training setting. The acceptable fields are:

    • augments (List[dict]): The batch augmentation methods to use. More details can be found in mmpretrain.model.utils.augment.

    Defaults to None.

  • with_cp (bool) – Whether to use gradient checkpointing. It saves some memory at the cost of slower training. Defaults to False.

  • data_preprocessor (dict, optional) – The config for preprocessing input data. If None, or if no type is specified, “ClsDataPreprocessor” is used as the type. See ClsDataPreprocessor for more details. Defaults to None.

  • init_cfg (dict, optional) – The config to control the initialization. Defaults to None.

  • **kwargs – Other keyword arguments of the function timm.models.create_model.
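The parameters above combine into a single model config. The sketch below is a hypothetical example: model_name, pretrained and num_classes are not wrapper options, so they would be forwarded to timm.models.create_model, while data_preprocessor omits `type` and would therefore fall back to ClsDataPreprocessor. The loss, batch-augment types, and normalization values follow patterns common in mmpretrain configs; treat them as assumptions and check them against your installed version.

```python
# Hypothetical TimmClassifier config (a plain dict, as used by mmpretrain configs).
model = dict(
    type='TimmClassifier',
    # Forwarded to timm.models.create_model:
    model_name='resnet50',
    pretrained=False,
    num_classes=10,
    # Wrapper options:
    loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ]),
    with_cp=True,  # trade training speed for lower memory use
    # No 'type' key, so ClsDataPreprocessor would be used.
    data_preprocessor=dict(
        num_classes=10,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True,
    ),
)
```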

Examples

>>> import torch
>>> from mmpretrain.models import build_classifier
>>> cfg = dict(type='TimmClassifier', model_name='resnet50', pretrained=True)
>>> model = build_classifier(cfg)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])

loss(inputs, data_samples, **kwargs)[source]

Calculate losses from a batch of inputs and data samples.

Parameters:
  • inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample]) – The annotation data of every sample.

  • **kwargs – Other keyword arguments of the loss module.

Returns:

A dictionary of loss components.

Return type:

dict[str, Tensor]
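To illustrate what the returned dictionary contains, assume the default CrossEntropyLoss with hard integer labels and mean reduction over the batch. The computation then reduces to the pure-Python sketch below, which scales the mean by loss_weight. The function name and the `'loss'` key are assumptions for illustration, and plain floats stand in for tensors.

```python
import math
from typing import Dict, Sequence


def cross_entropy_loss(logits: Sequence[Sequence[float]],
                       labels: Sequence[int],
                       loss_weight: float = 1.0) -> Dict[str, float]:
    """Mean softmax cross-entropy over a batch, scaled by loss_weight."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Subtract the row maximum before exponentiating for numerical stability.
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in row))
        # -log(softmax(row)[label]) == logsumexp(row) - row[label]
        total += log_sum_exp - row[label]
    return {'loss': loss_weight * total / len(labels)}


# Uniform logits over 4 classes give a loss of log(4), about 1.386.
print(cross_entropy_loss([[0.0, 0.0, 0.0, 0.0]], [2]))
```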

predict(inputs, data_samples=None)[source]

Predict results from a batch of inputs.

Parameters:
  • inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.

Returns:

The prediction results.

Return type:

List[DataSample]
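A common way for a classifier to turn logits into prediction results is a softmax for per-class scores followed by an argmax for the label. The sketch below assumes this post-processing and uses plain dicts with hypothetical keys `pred_score` and `pred_label` in place of DataSample objects; it illustrates the idea rather than reproducing mmpretrain's code.

```python
import math
from typing import Dict, List, Sequence, Union


def predictions_from_logits(
        logits: Sequence[Sequence[float]]) -> List[Dict[str, Union[int, List[float]]]]:
    """Softmax each row of logits and pick the highest-scoring class."""
    results = []
    for row in logits:
        peak = max(row)  # shift for numerical stability
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        scores = [e / total for e in exps]
        # max() returns the first maximal index, so ties go to the lower class id.
        label = max(range(len(scores)), key=scores.__getitem__)
        results.append({'pred_score': scores, 'pred_label': label})
    return results


preds = predictions_from_logits([[1.0, 3.0, 2.0], [0.0, 0.0]])
print([p['pred_label'] for p in preds])  # [1, 0]
```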
