Source code for mmpretrain.apis.feature_extractor

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from .base import BaseInferencer, InputType
from .model import list_models

class FeatureExtractor(BaseInferencer):
    """The inferencer for extracting features.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``FeatureExtractor.list_models()`` and you can also query it in
            :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import FeatureExtractor
        >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([256, 56, 56])
        torch.Size([512, 28, 28])
        torch.Size([1024, 14, 14])
        torch.Size([2048, 7, 7])
    """  # noqa: E501

    def __call__(self,
                 inputs: InputType,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            inputs (str | array | list): The image path or array, or a list of
                images.
            batch_size (int): Batch size. Defaults to 1.
            **kwargs: Other keyword arguments accepted by the `extract_feat`
                method of the model.

        Returns:
            list: One entry per input sample. Each entry has the same
            structure as the output of ``extract_feat`` (a tensor, or a
            sequence of tensors), indexed along the first (batch) dimension.
        """
        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(ori_inputs, batch_size=batch_size)
        preds = []
        for data in inputs:
            # ``forward`` already splits each batch into per-sample results,
            # so extending flattens the batches into one list.
            preds.extend(self.forward(data, **kwargs))

        return preds

    @torch.no_grad()
    def forward(self, inputs: Union[dict, tuple], **kwargs):
        """Extract features from one collated batch and split them per sample.

        Args:
            inputs (dict | tuple): A collated batch as yielded by
                :meth:`preprocess`.
            **kwargs: Forwarded unchanged to ``self.model.extract_feat``.

        Returns:
            list: ``inputs.shape[0]`` entries (after data preprocessing), one
            per sample, each mirroring the structure of the model output.
        """
        # The second argument is presumably the data preprocessor's
        # ``training`` flag; ``False`` selects inference-time behavior.
        inputs = self.model.data_preprocessor(inputs, False)['inputs']
        outputs = self.model.extract_feat(inputs, **kwargs)

        def scatter(feats, index):
            # Select sample ``index`` from a tensor, or recursively from every
            # tensor inside a (possibly nested) sequence, rebuilding the
            # sequence with its original container type.
            if isinstance(feats, torch.Tensor):
                return feats[index]
            # Sequence of tensors.
            return type(feats)([scatter(item, index) for item in feats])

        return [scatter(outputs, i) for i in range(inputs.shape[0])]

    def _init_pipeline(self, cfg: Config) -> Callable:
        """Build the test-time transform pipeline from ``cfg``.

        The pipeline is read from ``cfg.test_dataloader.dataset.pipeline``
        with ``LoadImageFromFile`` removed, because :meth:`preprocess` loads
        images itself before applying this pipeline.
        """
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Imported locally; presumably to avoid an import cycle at module load.
        from mmpretrain.datasets import remove_transform

        # Image loading is finished in `self.preprocess`.
        test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                             'LoadImageFromFile')
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):
        """Load, transform and collate inputs into batches.

        Args:
            inputs (List[InputType]): Image paths or arrays.
            batch_size (int): Number of samples per batch. Defaults to 1.

        Yields:
            The result of ``default_collate`` applied to each chunk of
            transformed samples produced by ``self._get_chunk_data``.

        Raises:
            ValueError: If an input image cannot be read.
        """

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        # Loading happens here instead of inside ``self.pipeline`` so that
        # both paths and in-memory arrays go through the same code path.
        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self):
        raise NotImplementedError(
            "The FeatureExtractor doesn't support visualization.")

    def postprocess(self):
        raise NotImplementedError(
            "The FeatureExtractor doesn't need postprocessing.")

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        # Resolves to the module-level ``list_models`` function, not this
        # method: class scope is not searched from inside a method body.
        return list_models(pattern=pattern)
Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.