
mmpretrain.apis.get_model

mmpretrain.apis.get_model(model, pretrained=False, device=None, device_map=None, offload_folder=None, url_mapping=None, **kwargs)

Get a pre-defined model or create a model from config.

Parameters:
  • model (str | Config) – The name of model, the config file path or a config instance.

  • pretrained (bool | str) – When the model is specified by name, pass True to load its pre-defined pretrained weights. You can also pass a string giving the path or URL of the weights to load. Defaults to False.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • device_map (str | dict | None) – A map that specifies which device each submodule should be placed on. It does not need to list every parameter or buffer name: once a module name is listed, all of its submodules are sent to the same device. Use device_map="auto" to generate the device map automatically. Defaults to None.

  • offload_folder (str | None) – The folder to offload weights to when device_map contains the value "disk". Defaults to None.

  • url_mapping (Tuple[str, str], optional) – A mapping applied to the pretrained checkpoint link. For example, pass ('https://.*/', './checkpoint') to load the checkpoint from a local directory instead of downloading it. Defaults to None.

  • **kwargs – Other keyword arguments of the model config.
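The two mapping-style arguments can be pictured with a small pure-Python sketch that does not call mmpretrain. `resolve_device` and `map_checkpoint_url` are hypothetical helpers. The longest-prefix rule is an assumption modelled on the description of device_map above, and the checkpoint rewrite assumes url_mapping is applied as `re.sub(pattern, replacement, link)`, which is what the example pair suggests. The URL is a placeholder.

```python
import re


def resolve_device(device_map: dict, param_name: str) -> str:
    """Return the device for a parameter under a module-prefix device map.

    Hypothetical helper: the entry whose key is the longest module-name prefix
    of ``param_name`` wins, so listing "backbone" covers every submodule such
    as "backbone.layer1.0.conv1.weight".
    """
    best = None
    for module_name, device in device_map.items():
        # "" stands for the whole model; otherwise require a full name segment,
        # so "backbone.layer4" does not match "backbone.layer40".
        if (module_name == "" or param_name == module_name
                or param_name.startswith(module_name + ".")):
            if best is None or len(module_name) > len(best[0]):
                best = (module_name, device)
    if best is None:
        raise KeyError(f"no device_map entry covers {param_name!r}")
    return best[1]


def map_checkpoint_url(url: str, url_mapping=None) -> str:
    """Rewrite a checkpoint link, assuming re.sub(pattern, repl, url) semantics."""
    if url_mapping is None:
        return url
    pattern, repl = url_mapping
    return re.sub(pattern, repl, url)


# "disk" entries are the case where offload_folder would be needed.
device_map = {"backbone": "cuda:0", "backbone.layer4": "cpu", "head": "disk"}
print(resolve_device(device_map, "backbone.layer1.0.conv1.weight"))  # cuda:0
print(resolve_device(device_map, "backbone.layer4.2.bn3.bias"))      # cpu

url = "https://example.com/ckpts/resnet50.pth"  # placeholder link
print(map_checkpoint_url(url, (r"https://.*/", "./checkpoint/")))
# ./checkpoint/resnet50.pth
```

Under the re.sub assumption, the greedy `.*/` consumes everything up to the last slash, so the replacement needs its own trailing slash; without it, the directory and file name would be joined as `./checkpointresnet50.pth`.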

Returns:

The resulting model.

Return type:

mmengine.model.BaseModel

Examples

Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmpretrain import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
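The ResNet-50 example passes only `backbone=dict(out_indices=(0, 1, 2, 3))`, yet the backbone keeps its type and depth. This suggests the extra keyword arguments are merged recursively into the model config rather than replacing whole sub-dicts. The sketch below shows that behaviour with a hypothetical `merge_model_kwargs` helper and an abridged, illustrative config dict. It is an assumption about how get_model treats **kwargs, not its actual code. It also checks the channel arithmetic behind the printed shapes: each ResNet-50 stage ends in bottleneck blocks with expansion 4, giving widths of 64 · 2^i · 4.

```python
from copy import deepcopy


def merge_model_kwargs(model_cfg: dict, overrides: dict) -> dict:
    """Recursively overlay keyword overrides onto a model config dict.

    Hypothetical helper: nested dicts are merged key by key, and any other
    value (int, tuple, str) replaces the existing one.
    """
    merged = deepcopy(model_cfg)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_model_kwargs(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Abridged, illustrative stand-in for a ResNet-50 classifier config.
model_cfg = {
    "type": "ImageClassifier",
    "backbone": {"type": "ResNet", "depth": 50, "out_indices": (3,)},
    "neck": {"type": "GlobalAveragePooling"},
}
merged = merge_model_kwargs(model_cfg, {"backbone": {"out_indices": (0, 1, 2, 3)}})
print(merged["backbone"])
# {'type': 'ResNet', 'depth': 50, 'out_indices': (0, 1, 2, 3)}

# Stage widths: base width 64 doubled per stage, times bottleneck expansion 4.
print([64 * 2**i * 4 for i in range(4)])  # [256, 512, 1024, 2048]
```

The two-dimensional shapes in the example are consistent with a pooling neck reducing each stage's feature map to one vector per image.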

Get a Swin-Transformer model with pretrained weights and run inference:

>>> from mmpretrain import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake
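The example reads only the `pred_class` key of the result. As a rough sketch of how a top-1 class name can be derived from raw model outputs, the hypothetical `build_prediction` helper below applies a softmax and picks the highest-scoring label. The other key names (`pred_scores`, `pred_label`, `pred_score`), the logits, and the class list are illustrative assumptions, not the documented layout of the inference result.

```python
import math


def build_prediction(logits, classes):
    """Turn raw class logits into a classification result dict.

    Hypothetical helper: only "pred_class" mirrors the example above; the
    other keys are assumptions about the result layout.
    """
    # Numerically stable softmax: subtract the maximum before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    # Top-1 label is the index of the largest score.
    label = max(range(len(scores)), key=scores.__getitem__)
    return {
        "pred_scores": scores,
        "pred_label": label,
        "pred_score": scores[label],
        "pred_class": classes[label],
    }


result = build_prediction([0.3, 2.1, -1.0], ["tench", "sea snake", "goldfish"])
print(result["pred_class"])  # sea snake
```

Note that print writes the class name without quotes, which is why the doctest output above is a bare string.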