# mmcls.models.backbones.timm_backbone source code

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.logging import MMLogger

from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

def print_timm_feature_info(feature_info):
    """Print feature_info of timm backbone to help development and debug.

    Three shapes of ``feature_info`` are handled:

    * ``None``: a warning is logged because the model exposes no info.
    * ``list``: each entry is logged with its index.
    * anything else: treated as a ``timm`` ``FeatureInfo``-like object whose
      ``out_indices``, ``channels()`` and ``reduction()`` are logged; if any
      of them is missing, a warning is logged instead of raising.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    logger = MMLogger.get_current_instance()
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
    elif isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
    else:
        # Best-effort introspection: an object without the FeatureInfo API
        # is reported rather than allowed to abort model construction.
        try:
            logger.info(f'backbone out_indices: {feature_info.out_indices}')
            logger.info(f'backbone out_channels: {feature_info.channels()}')
            logger.info(f'backbone out_strides: {feature_info.reduction()}')
        except AttributeError:
            logger.warning('Unexpected format of backbone feature_info')

@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments. A ``norm_layer``
            given as a string is resolved through the ``MODELS`` registry;
            any other value is forwarded to timm unchanged.

    Raises:
        ImportError: If ``timm`` is not installed.
        TypeError: If ``pretrained`` is not a bool.
        KeyError: If ``norm_layer`` is a string not registered in ``MODELS``.
    """

    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        # timm is an optional dependency, so import it lazily.
        try:
            import timm
        except ImportError as err:
            raise ImportError(
                'Failed to import timm. Please run "pip install timm".'
            ) from err

        # A str here usually means a checkpoint path was passed by mistake;
        # paths belong in ``checkpoint_path``.
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super().__init__(init_cfg)

        # Configs can only name a norm layer by string; map that name to the
        # registered class. Non-string values (a class, None) pass through.
        norm_layer = kwargs.get('norm_layer')
        if isinstance(norm_layer, str):
            norm_cls = MODELS.get(norm_layer)
            if norm_cls is None:
                raise KeyError(
                    f'norm_layer {norm_layer!r} is not registered in MODELS')
            kwargs['norm_layer'] = norm_cls

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # Drop the classification head (0 classes, empty global-pool type)
        # so the wrapped model acts as a pure backbone.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm: the weights are already
        # loaded, so mark the module as initialized.
        # NOTE(review): relies on the private ``_is_init`` flag of the
        # mmengine BaseModule so that ``init_weights`` does not overwrite them.
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        """Run the wrapped timm model.

        Args:
            x: Input batch, passed unchanged to the timm model.

        Returns:
            tuple: The model outputs. A list/tuple output is converted to a
            tuple; any other single output is wrapped in a 1-tuple so callers
            always receive a tuple.
        """
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            return tuple(features)
        return (features, )
# Read the Docs v: mmcls-1.x
# On Read the Docs
# Project Home
#
# Free document hosting provided by Read the Docs.