mmpretrain.apis.get_model
- mmpretrain.apis.get_model(model, pretrained=False, device=None, device_map=None, offload_folder=None, url_mapping=None, **kwargs)
Get a pre-defined model or create a model from config.
- Parameters:
model (str | Config) – The name of the model, the path of a config file, or a config instance.
pretrained (bool | str) – When the model is specified by name, pass True to load its pre-defined pretrained weights. You can also pass a string to specify the path or URL of the weights to load. Defaults to False.
device (str | torch.device | None) – The device to move the model to. Defaults to None.
device_map (str | dict | None) – A map that specifies which device each submodule should be placed on. It does not need to list every parameter or buffer name: once a module name is included, all of its submodules are sent to the same device. Use device_map="auto" to generate the device map automatically. Defaults to None.
offload_folder (str | None) – The folder to offload weights to if device_map contains the value "disk". Defaults to None.
url_mapping (Tuple[str, str], optional) – A mapping applied to the link of the pretrained checkpoint. For example, use ('https://.*/', './checkpoint') to load the checkpoint from a local directory instead of downloading it. Defaults to None.
**kwargs – Other keyword arguments of the model config.
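Two of the matching rules in the parameter list can be illustrated in plain Python. This is a minimal sketch, assuming that url_mapping is applied to the checkpoint link as a regular-expression substitution with re.sub (the example pattern 'https://.*/' is a regex) and that device_map keys are dotted module names. apply_url_mapping and device_for are hypothetical helpers, not mmpretrain functions, and the URL and module names are placeholders.

```python
import re
from typing import Dict, Optional, Tuple


def apply_url_mapping(url: str,
                      url_mapping: Optional[Tuple[str, str]] = None) -> str:
    """Hypothetical helper: rewrite a checkpoint link with a (pattern, replacement) pair.

    Assumes the mapping is applied with re.sub; this is not mmpretrain code.
    """
    if url_mapping is None:
        return url
    pattern, replacement = url_mapping
    return re.sub(pattern, replacement, url)


def device_for(module_name: str, device_map: Dict[str, str]) -> Optional[str]:
    """Hypothetical helper: return the device of the closest listed ancestor.

    A key such as 'backbone' also covers 'backbone.layer1.0.conv1'; an
    empty-string key, if present, covers the whole model.
    """
    parts = module_name.split('.')
    # Walk from the full dotted name up to its shortest prefix.
    for i in range(len(parts), 0, -1):
        prefix = '.'.join(parts[:i])
        if prefix in device_map:
            return device_map[prefix]
    return device_map.get('')


# Placeholder link; the greedy '.*' matches up to the last '/'.
url = 'https://example.com/weights/resnet50.pth'
print(apply_url_mapping(url, ('https://.*/', './checkpoint/')))  # ./checkpoint/resnet50.pth

# Illustrative device map; every submodule of 'backbone' lands on cuda:0.
device_map = {'backbone': 'cuda:0', 'neck': 'cuda:1', 'head': 'disk'}
print(device_for('backbone.layer1.0.conv1', device_map))  # cuda:0
print(device_for('head.fc', device_map))  # disk
```

Because the pattern 'https://.*/' also consumes the final slash, under the re.sub assumption a replacement without a trailing slash, such as './checkpoint', would join the directory and file names into './checkpointresnet50.pth'; the sketch therefore uses './checkpoint/'.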
- Returns:
The resulting model.
- Return type:
Examples
Get a ResNet-50 model and extract image features:
>>> import torch
>>> from mmpretrain import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True,
...                   backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
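In the example above, backbone=dict(out_indices=(0, 1, 2, 3)) overrides a single backbone setting while the backbone still keeps its type and depth, which suggests that **kwargs are merged recursively into the model config rather than replacing whole sub-dicts. The following is a minimal sketch of such a merge, assuming behavior similar to mmengine's Config.merge_from_dict; merge_cfg and the trimmed model_cfg are hypothetical and are not mmpretrain code.

```python
import copy
from typing import Any, Dict


def merge_cfg(cfg: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: recursively merge ``overrides`` into a copy of ``cfg``.

    Nested dicts are merged key by key; any other value replaces the old one.
    """
    merged = copy.deepcopy(cfg)  # leave the original config untouched
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = value
    return merged


# Simplified, hypothetical stand-in for the ResNet-50 model config.
model_cfg = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'ResNet', 'depth': 50, 'out_indices': (3,)},
    'neck': {'type': 'GlobalAveragePooling'},
}

# Plays the role of passing backbone=dict(out_indices=(0, 1, 2, 3)) as **kwargs.
merged = merge_cfg(model_cfg, {'backbone': {'out_indices': (0, 1, 2, 3)}})
print(merged['backbone'])
# {'type': 'ResNet', 'depth': 50, 'out_indices': (0, 1, 2, 3)}
```

With all four stage indices selected, extract_feat returns one tensor per ResNet-50 stage. Their widths of 256, 512, 1024 and 2048 channels are 64 × 4 (the bottleneck expansion factor) doubled at each successive stage.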
Get a Swin-Transformer model with pre-trained weights and run inference:
>>> from mmpretrain import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake