mmcls.apis.get_model
- mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)
Get a pre-defined model by its name.
- Parameters
model_name (str) – The name of the model.
pretrained (bool | str) – If True, load the pre-defined pretrained weights of the model. If a string, load the weights from the checkpoint it points to. Defaults to False.
device (str | torch.device | None) – The device to move the model to. Defaults to None.
**kwargs – Other keyword arguments used to override the model config.
- Returns
The built model.
- Return type
Examples
Get a ResNet-50 model and extract image features:
>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
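In this example, `backbone=dict(out_indices=(0, 1, 2, 3))` changes only `out_indices` while the backbone otherwise stays a ResNet-50, which suggests the keyword arguments are merged recursively into the model config rather than replacing whole sub-dicts. The plain-Python sketch below illustrates that kind of merge under that assumption; `merge_model_kwargs` and the abbreviated config are hypothetical and are not the mmcls implementation.

```python
import copy
from typing import Any, Dict


def merge_model_kwargs(model_cfg: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: recursively merge overrides into a copy of a config."""
    merged = copy.deepcopy(model_cfg)  # never mutate the caller's config
    for key, value in kwargs.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged key by key, so unspecified keys survive.
            merged[key] = merge_model_kwargs(merged[key], **value)
        else:
            # Any other value replaces the existing entry outright.
            merged[key] = value
    return merged


# Hypothetical, heavily abbreviated ResNet-50 classifier config.
base_cfg = {
    "type": "ImageClassifier",
    "backbone": {"type": "ResNet", "depth": 50, "out_indices": (3,)},
    "neck": {"type": "GlobalAveragePooling"},
}

cfg = merge_model_kwargs(base_cfg, backbone=dict(out_indices=(0, 1, 2, 3)))
print(cfg["backbone"])
# {'type': 'ResNet', 'depth': 50, 'out_indices': (0, 1, 2, 3)}
```

A shallow `dict.update` would instead replace the whole `backbone` entry with `{'out_indices': (0, 1, 2, 3)}` and lose the backbone type, which is why the recursive form fits the example above.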
Get a Swin-Transformer model with pre-trained weights and run inference:
>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake
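The `pretrained` parameter accepts either a bool or a string. The sketch below shows one way such an argument could be resolved to a checkpoint location, assuming a registry that maps model names to their pre-defined weights; `DEFAULT_WEIGHTS`, `resolve_pretrained` and the placeholder values are hypothetical and are not part of the mmcls API.

```python
from typing import Dict, Optional, Union

# Hypothetical stand-in for a model registry: model name -> default checkpoint.
# The values are placeholders, not real checkpoint URLs.
DEFAULT_WEIGHTS: Dict[str, Optional[str]] = {
    "resnet50_8xb32_in1k": "<default-checkpoint-for-resnet50_8xb32_in1k>",
    "swin-base_16xb64_in1k": "<default-checkpoint-for-swin-base_16xb64_in1k>",
}


def resolve_pretrained(model_name: str,
                       pretrained: Union[bool, str] = False) -> Optional[str]:
    """Return the checkpoint to load, or None to keep the initial weights."""
    if isinstance(pretrained, str):
        # A string is used as given: a checkpoint chosen by the caller.
        return pretrained
    if pretrained is True:
        # True selects the pre-defined weights registered for this model.
        return DEFAULT_WEIGHTS.get(model_name)
    # False (the default): no weights are loaded.
    return None


print(resolve_pretrained("swin-base_16xb64_in1k", True))
# <default-checkpoint-for-swin-base_16xb64_in1k>
```

Checking for a string before testing `is True` keeps the two branches unambiguous, since a non-empty string is truthy but is not the boolean `True`.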