Shortcuts

Source code for mmcls.apis.model

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
    """A hub to host the meta information of all pre-defined models.

    All state is stored on the class: ``_models_dict`` maps lower-cased model
    names to their ``modelindex`` metainfo, and the model-index bundled with
    mmcls is registered lazily the first time :meth:`get` or :meth:`has` is
    called.
    """
    # Registered model metainfo, keyed by lower-cased model name.
    _models_dict = {}
    # Whether the model-index bundled with mmcls has been registered yet.
    # (Name-mangled to ``_ModelHub__mmcls_registered``.)
    __mmcls_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike, None] = None):
        """Parse the model-index file and register all models.

        Model names are stored lower-cased, so later lookups are
        case-insensitive.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file. If None, relative config
                paths are resolved against the directory of each model's
                ``metainfo.filepath``.

        Raises:
            ValueError: If a model name (compared case-insensitively) is
                already registered.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            # Check the normalized key: the dict is keyed by lower-cased
            # names, so testing ``metainfo.name`` would miss conflicts for any
            # name with upper-case letters and silently overwrite the entry.
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} is conflict in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of model (case-insensitive).

        Returns:
            modelindex.models.Model: A deep copy of the registered metainfo.
            If its ``config`` is stored as a file path, it is loaded into an
            :obj:`mmengine.Config`.

        Raises:
            ValueError: If no model with this name is registered.
        """
        cls._register_mmcls_models()
        registered = cls._models_dict.get(model_name.lower())
        if registered is None:
            raise ValueError(f'Failed to find model {model_name}.')
        # Work on a copy so the lazily loaded Config below does not replace
        # the config path stored in the shared registry entry.
        metainfo = copy.deepcopy(registered)
        if isinstance(metainfo.config, str):
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: 'Model',
                            config_prefix: Union[str, PathLike, None] = None):
        """Resolve ``metainfo.config`` to an absolute path.

        Args:
            metainfo (modelindex.models.Model): The model metainfo.
            config_prefix (str | PathLike | None): Directory that a relative
                config path is joined to. Defaults to the directory of
                ``metainfo.filepath``.

        Returns:
            str | None: The absolute config path. The original value is
            returned unchanged if it is None or already absolute.
        """
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path

    @classmethod
    def _register_mmcls_models(cls):
        """Register the model-index bundled with mmcls, at most once."""
        # The flag is read and written on ``ModelHub`` itself, not ``cls``:
        # ``_models_dict`` is shared with subclasses, so a per-subclass flag
        # would let a later ``ModelHub`` call register everything again.
        if not ModelHub.__mmcls_registered:
            from mmengine.utils import get_installed_path
            mmcls_root = Path(get_installed_path('mmcls'))
            model_index_path = mmcls_root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=mmcls_root / '.mim')
            ModelHub.__mmcls_registered = True

    @classmethod
    def has(cls, model_name):
        """Whether a model name is in the ModelHub.

        Like :meth:`get`, this first registers the models bundled with mmcls
        and matches the name case-insensitively.
        """
        cls._register_mmcls_models()
        return model_name.lower() in cls._models_dict


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Initialize a classifier from config file.

    Args:
        config (str | PathLike | :obj:`mmengine.Config`): Config file path or
            the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        nn.Module: The constructed model.

    Raises:
        TypeError: If ``config`` is neither a path nor a Config object.
    """
    if isinstance(config, (str, PathLike)):
        config = Config.fromfile(config)
    elif isinstance(config, Config):
        # Copy so that merging ``kwargs`` below and storing the config on the
        # model never mutate the caller's Config object.
        config = copy.deepcopy(config)
    else:
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    if kwargs:
        config.merge_from_dict({'model': kwargs})
    # Use the top-level ``data_preprocessor`` (or None) when the model config
    # does not define its own.
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    from mmcls.registry import MODELS
    model = MODELS.build(config.model)

    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = ckpt.get('meta', {})
        if not model.with_head:
            # Don't set CLASSES if the model is headless.
            pass
        elif 'dataset_meta' in meta:
            # mmcls 1.x
            model.CLASSES = meta['dataset_meta'].get('classes')
        elif 'CLASSES' in meta:
            # mmcls < 1.x
            model.CLASSES = meta['CLASSES']
        else:
            from mmcls.datasets.categories import IMAGENET_CATEGORIES
            # Do not call ``warnings.simplefilter('once')`` here: it would
            # change the filter for every warning in the whole process. Under
            # Python's default filters this warning is already shown only
            # once per location.
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = IMAGENET_CATEGORIES

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def get_model(model_name, pretrained=False, device=None, **kwargs):
    """Get a pre-defined model by the name of model.

    Args:
        model_name (str): The name of model.
        pretrained (bool | str | PathLike): If True, load the pre-defined
            pretrained weights. If a path, load the weights from it.
            Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Raises:
        ValueError: If the model name is not registered, or ``pretrained`` is
            True but the model has no pre-defined weights.

    Examples:
        Get a ResNet-50 model and extract images feature:

        >>> import torch
        >>> from mmcls import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get Swin-Transformer model with pre-trained weights and inference:

        >>> from mmcls import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        sea snake
    """  # noqa: E501
    metainfo = ModelHub.get(model_name)

    if isinstance(pretrained, (str, PathLike)):
        # Test for a path before truthiness: a ``pathlib.Path`` is always
        # truthy and would otherwise be mistaken for ``pretrained=True``.
        ckpt = str(pretrained)
    elif pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        ckpt = metainfo.weights
    else:
        ckpt = None

    model = init_model(metainfo.config, ckpt, device=device, **kwargs)
    return model


def list_models(pattern=None) -> List[str]:
    """List all models available in MMClassification.

    Args:
        pattern (str | None): A wildcard pattern (``fnmatch`` syntax) to match
            model names. Matching is case-insensitive, and any suffix is
            allowed after the pattern. Defaults to None, which lists all
            models.

    Returns:
        List[str]: A sorted list of (lower-cased) model names.

    Examples:
        List all models:

        >>> from mmcls import list_models
        >>> print(list_models())

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmcls import list_models
        >>> print(list_models('resnet50*in1k'))
        ['resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb32_in1k']
    """
    ModelHub._register_mmcls_models()
    names = sorted(ModelHub._models_dict)
    if pattern is None:
        return names
    # Registered names are lower-cased, so normalize the pattern as well
    # (``fnmatch`` is case-sensitive on POSIX). The trailing '*' lets the
    # pattern match any suffix. ``fnmatch.filter`` keeps the input order, so
    # the result stays sorted.
    return fnmatch.filter(names, pattern.lower() + '*')
Read the Docs v: mmcls-1.x
Versions
latest
stable
mmcls-1.x
mmcls-0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.