Source code for mmpretrain.models.classifiers.base

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Sequence

import torch
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement

class BaseClassifier(BaseModel, metaclass=ABCMeta):
    """Base class for classifiers.

    Args:
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None, it will use "BaseDataPreprocessor" as type, see
            :class:`mmengine.model.BaseDataPreprocessor` for more details.
            Defaults to None.

    Attributes:
        init_cfg (dict): Initialization config dict.
        data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An
            extra data pre-processing module, which processes data from
            dataloader to the format accepted by :meth:`forward`.
    """

    def __init__(self,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None):
        # Both arguments are forwarded unchanged to mmengine's BaseModel.
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

    @property
    def with_neck(self) -> bool:
        """Whether the classifier has a neck.

        True only if the subclass defines a ``neck`` attribute AND it is not
        None; subclasses are free not to define ``neck`` at all.
        """
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """Whether the classifier has a head.

        True only if the subclass defines a ``head`` attribute AND it is not
        None; subclasses are free not to define ``head`` at all.
        """
        return hasattr(self, 'head') and self.head is not None

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`BaseDataElement`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating; those are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...)
                in general.
            data_samples (List[BaseDataElement], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmengine.BaseDataElement`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        pass

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input tensor with shape (N, C, ...).

        Subclasses are recommended to implement this method to extract
        features from the backbone and neck.

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override it.
        """
        raise NotImplementedError

    def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
                      **kwargs) -> list:
        """Extract features from a sequence of input tensors.

        Args:
            multi_inputs (Sequence[torch.Tensor]): A sequence of input
                tensors. It can be used in augmented inference.
            **kwargs: Other keyword arguments accepted by :meth:`extract_feat`.

        Returns:
            list: Features of every input tensor, in the same order as
            ``multi_inputs``.
        """
        # Reject non-sequence input (e.g. a single batch); the message points
        # callers at ``extract_feat`` for that case.
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # so the raised exception type (AssertionError) does not change.
        assert isinstance(multi_inputs, Sequence), \
            '`extract_feats` is used for a sequence of inputs tensor. If you '\
            'want to extract on single inputs tensor, use `extract_feat`.'
        return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]
Read the Docs v: dev
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.