Shortcuts

mmcls.apis.get_model

mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)[source]

Get a pre-defined model by its name.

Parameters
  • model_name (str) – The name of the model.

  • pretrained (bool | str) – If True, load the model's pre-defined pre-trained weights. If a string, load the weights from the checkpoint it specifies. Defaults to False.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • **kwargs – Other keyword arguments that override fields of the model config.
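How pretrained and **kwargs might be resolved can be sketched in plain Python. This is a minimal illustration under assumptions, not the mmcls implementation: resolve_checkpoint, merge_overrides and the default_weights argument are hypothetical names, and the recursive merge is an assumption, chosen because the ResNet-50 example below overrides only backbone.out_indices while the rest of the backbone config has to survive.

```python
import copy
from typing import Optional, Union


def resolve_checkpoint(pretrained: Union[bool, str],
                       default_weights: Optional[str]) -> Optional[str]:
    """Hypothetical helper: map the `pretrained` argument to a checkpoint or None."""
    if isinstance(pretrained, str):
        # A string is used directly as the checkpoint to load.
        return pretrained
    if pretrained:
        # True means "use the weights registered for this model name".
        if default_weights is None:
            raise ValueError('No pre-trained weights are registered for this model.')
        return default_weights
    # False: build the model without loading any checkpoint.
    return None


def merge_overrides(base: dict, overrides: dict) -> dict:
    """Hypothetical helper: recursively merge keyword overrides into a config copy."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged, so unspecified fields keep their defaults.
            merged[key] = merge_overrides(merged[key], value)
        else:
            # Anything else replaces the existing value outright.
            merged[key] = value
    return merged


# Toy model config (assumed layout) and the override used in the ResNet-50 example.
model_cfg = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'ResNet', 'depth': 50, 'out_indices': (3,)},
}
cfg = merge_overrides(model_cfg, {'backbone': {'out_indices': (0, 1, 2, 3)}})
print(cfg['backbone'])
# {'type': 'ResNet', 'depth': 50, 'out_indices': (0, 1, 2, 3)}
```

Merging into a deep copy leaves the registered config untouched, so repeated calls with different overrides do not interfere with each other.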

Returns

The built model.

Return type

mmengine.model.BaseModel

Examples

Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
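The printed features are 2-D (batch, channels), presumably because each stage's feature map is pooled over its spatial dimensions before being returned. The NumPy sketch below reproduces that shape arithmetic under stated assumptions: that the resnet50_8xb32_in1k config uses a global-average-pooling neck, and that the four ResNet-50 stages reduce a 224x224 input to 56, 28, 14 and 7 pixels per side. It uses a batch of 2 instead of 16 and random arrays in place of real activations.

```python
import numpy as np

# Assumed (batch, channels, height, width) of ResNet-50's four stage outputs
# for 224x224 inputs; a batch of 2 keeps the sketch light.
stage_shapes = [(2, 256, 56, 56), (2, 512, 28, 28),
                (2, 1024, 14, 14), (2, 2048, 7, 7)]
rng = np.random.default_rng(0)
# Random stand-ins for real activations.
stage_outputs = [rng.random(shape, dtype=np.float32) for shape in stage_shapes]


def global_average_pool(feat: np.ndarray) -> np.ndarray:
    # Average over the spatial axes (H, W): (N, C, H, W) -> (N, C).
    return feat.mean(axis=(2, 3))


pooled = tuple(global_average_pool(f) for f in stage_outputs)
for feat in pooled:
    print(feat.shape)
# (2, 256)
# (2, 512)
# (2, 1024)
# (2, 2048)
```

Only the channel count survives pooling, which is why the four outputs differ in their second dimension alone.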

Get a Swin-Transformer model with pre-trained weights and run inference:

>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake
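To show how a class name such as "sea snake" can be derived from a classifier's raw output, here is a toy sketch. It assumes the prediction is the argmax of softmax scores mapped to a list of class names; the postprocess helper, the three made-up classes, and every result key other than pred_class are hypothetical and are not taken from inference_model itself.

```python
import numpy as np


def postprocess(logits: np.ndarray, class_names: list) -> dict:
    """Hypothetical helper: turn one image's raw logits into a prediction dict."""
    # Numerically stable softmax over the class dimension.
    shifted = np.exp(logits - logits.max())
    scores = shifted / shifted.sum()
    # The highest-scoring index selects the predicted class.
    label = int(scores.argmax())
    return {
        'pred_scores': scores,
        'pred_label': label,
        'pred_score': float(scores[label]),
        'pred_class': class_names[label],
    }


# Toy example with three made-up classes instead of the 1000 ImageNet classes.
class_names = ['goldfish', 'sea snake', 'tabby cat']
result = postprocess(np.array([0.5, 2.0, -1.0]), class_names)
print(result['pred_class'])  # sea snake
```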