
BaseRetriever

class mmpretrain.models.retrievers.BaseRetriever(prototype=None, data_preprocessor=None, init_cfg=None)[source]

Base class for retrievers.

Parameters:
  • init_cfg (dict, optional) – Initialization config dict. Defaults to None.

  • data_preprocessor (dict, optional) – The config for preprocessing input data. If None, it will use “BaseDataPreprocessor” as type, see mmengine.model.BaseDataPreprocessor for more details. Defaults to None.

  • prototype (Union[DataLoader, dict, str, torch.Tensor]) –

    Database to be retrieved. The following four types are supported.

    • DataLoader: The original dataloader serves as the prototype.

    • dict: The configuration to construct a DataLoader.

    • str: The path of the saved feature vectors.

    • torch.Tensor: The saved tensor whose dimension should be dim.
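The four accepted prototype types could be normalized into one feature tensor roughly as below. This is a minimal sketch, not mmpretrain's implementation. The helper name resolve_prototype, the toy pooling extractor, and the assumption that a str path points to a tensor written with torch.save are all hypothetical. The dict case, which needs a config registry to build the DataLoader, is deliberately left unimplemented.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def resolve_prototype(prototype, extract_fn):
    """Hypothetical helper: turn a prototype argument into a (M, dim) tensor.

    extract_fn stands in for a feature extractor such as extract_feat.
    """
    if isinstance(prototype, torch.Tensor):
        # Already a feature tensor: use it directly.
        return prototype
    if isinstance(prototype, str):
        # Assumed: a path to features previously saved with torch.save.
        return torch.load(prototype)
    if isinstance(prototype, DataLoader):
        # Run every batch through the extractor and concatenate the results.
        with torch.no_grad():
            feats = [extract_fn(batch[0]) for batch in prototype]
        return torch.cat(feats, dim=0)
    if isinstance(prototype, dict):
        # Building a DataLoader from a config needs a registry; not sketched.
        raise NotImplementedError('building a DataLoader from a dict is not sketched')
    raise TypeError(f'unsupported prototype type: {type(prototype)}')


def global_avg_pool(x: torch.Tensor) -> torch.Tensor:
    # Toy extractor: (N, C, H, W) -> (N, C).
    return x.mean(dim=(2, 3))


# Usage: 10 images of shape (3, 8, 8), served in batches of 4.
images = torch.randn(10, 3, 8, 8)
loader = DataLoader(TensorDataset(images), batch_size=4)
feats = resolve_prototype(loader, global_avg_pool)
print(feats.shape)  # torch.Size([10, 3])
```

Because the DataLoader is not shuffled, row i of the result corresponds to the i-th prototype sample.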

prototype

Database to be retrieved. It accepts the same four types as the prototype parameter: DataLoader, dict, str and torch.Tensor.

Type:

Union[DataLoader, dict, str, torch.Tensor]

data_preprocessor

An extra data pre-processing module, which processes data from dataloader to the format accepted by forward().

Type:

mmengine.model.BaseDataPreprocessor

dump_prototype(path)[source]

Save the features extracted from the prototype to the given path.

Parameters:

path (str) – Path to save the features.
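A saved feature file is what the str form of prototype refers to. The round trip below is a minimal sketch, assuming the features are a plain tensor stored with torch.save. The exact on-disk format mmpretrain uses is not stated here, and the dump_features helper and the prototype.pth file name are hypothetical.

```python
import os
import tempfile

import torch


def dump_features(features: torch.Tensor, path: str) -> None:
    """Hypothetical stand-in for dump_prototype: save prototype features."""
    # Detach and move to CPU so the file can be loaded on a machine without a GPU.
    torch.save(features.detach().cpu(), path)


features = torch.randn(5, 16)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'prototype.pth')
    dump_features(features, path)
    # The saved file could later be passed back as a str prototype.
    restored = torch.load(path)
print(torch.equal(features, restored))  # True
```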

extract_feat(inputs)[source]

Extract features from the input tensor with shape (N, C, …).

Subclasses are recommended to implement this method to extract features from the backbone and neck.

Parameters:

inputs (Tensor) – A batch of inputs. Its shape should be (num_samples, num_channels, *img_shape).
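The backbone-then-neck split recommended for extract_feat can be illustrated with a toy module. This is a sketch only: the ToyRetriever class, the single convolution used as a backbone, and global average pooling as the neck are hypothetical choices, not mmpretrain components.

```python
import torch
from torch import nn


class ToyRetriever(nn.Module):
    """Hypothetical model showing the backbone -> neck split of extract_feat."""

    def __init__(self, in_channels: int = 3, feat_dim: int = 8):
        super().__init__()
        # Backbone: (N, C, H, W) images -> (N, feat_dim, H, W) feature maps.
        self.backbone = nn.Conv2d(in_channels, feat_dim, kernel_size=3, padding=1)
        # Neck: global average pooling to one vector per sample.
        self.neck = nn.AdaptiveAvgPool2d(1)

    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.backbone(inputs)
        x = self.neck(x)       # (N, feat_dim, 1, 1)
        return x.flatten(1)    # (N, feat_dim)


model = ToyRetriever()
feats = model.extract_feat(torch.randn(2, 3, 32, 32))
print(feats.shape)  # torch.Size([2, 8])
```

The flattened (N, C) output is the shape that matching expects as its input.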

abstract forward(inputs, data_samples=None, mode='loss')[source]

The unified entry for a forward process in both training and test.

The method should accept three modes: “tensor”, “predict” and “loss”:

  • “tensor”: Forward the whole network and return tensor without any post-processing, same as a common nn.Module.

  • “predict”: Forward and return the predictions, which are fully processed to a list of DataSample.

  • “loss”: Forward and return a dict of losses according to the given inputs and data samples.

Note that this method handles neither back propagation nor optimizer updating; both are done in train_step().

Parameters:
  • inputs (torch.Tensor, tuple) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. It is required if mode="loss". Defaults to None.

  • mode (str) – The kind of value to return. Defaults to ‘loss’.

Returns:

The return type depends on mode.
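The three modes can be sketched as a simple dispatch on the mode string. This toy module is hypothetical and is not a BaseRetriever subclass. Plain integer labels stand in for DataSample objects, a linear layer stands in for the network, and raising RuntimeError on an unknown mode is an assumption of this sketch.

```python
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn


class ToyForward(nn.Module):
    """Hypothetical module mirroring the 'tensor', 'predict' and 'loss' modes."""

    def __init__(self, in_dim: int = 4, num_classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, inputs: torch.Tensor,
                data_samples: Optional[List[int]] = None,
                mode: str = 'loss'):
        if mode == 'tensor':
            # Raw network output, no post-processing.
            return self.fc(inputs)
        if mode == 'predict':
            # Post-processed predictions: one class index per sample.
            return self.fc(inputs).argmax(dim=1).tolist()
        if mode == 'loss':
            # Labels are required in this mode; ints stand in for DataSample.
            labels = torch.tensor(data_samples)
            return {'loss': F.cross_entropy(self.fc(inputs), labels)}
        raise RuntimeError(f'Invalid mode "{mode}".')


model = ToyForward()
x = torch.randn(2, 4)
print(model(x, mode='tensor').shape)           # torch.Size([2, 3])
print(sorted(model(x, [0, 2], mode='loss')))   # ['loss']
```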

loss(inputs, data_samples)[source]

Calculate losses from a batch of inputs and data samples.

Parameters:
  • inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample]) – The annotation data of every sample.

Returns:

A dictionary of loss components.

Return type:

dict[str, Tensor]
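A dictionary of loss components still has to be reduced to one scalar before back propagation. The sketch below follows a simplified version of the convention used by mmengine's BaseModel.parse_losses, which sums the entries whose key contains 'loss'. The sum_losses helper and the example keys are hypothetical, and the real method also handles lists of tensors and logging.

```python
import torch


def sum_losses(losses: dict) -> torch.Tensor:
    """Hypothetical reduction: sum only entries whose key contains 'loss'.

    Other entries (e.g. accuracy) are treated as values to log, not optimize.
    """
    return sum(value for key, value in losses.items() if 'loss' in key)


losses = {
    'loss_cls': torch.tensor(0.5),
    'loss_aux': torch.tensor(0.25),
    'accuracy': torch.tensor(0.9),  # logged, not optimized
}
print(sum_losses(losses))  # tensor(0.7500)
```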

matching(inputs)[source]

Compare the inputs with the prototype and calculate their similarity.

Parameters:

inputs (torch.Tensor) – The input tensor with shape (N, C).
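One common way to compare (N, C) query features against an (M, C) prototype is cosine similarity, which yields an (N, M) score matrix. This is a sketch assuming cosine similarity. The base class does not fix the similarity function here, and the cosine_matching name is hypothetical.

```python
import torch
import torch.nn.functional as F


def cosine_matching(inputs: torch.Tensor, prototype: torch.Tensor) -> torch.Tensor:
    """Similarity of each query (N, C) to each prototype vector (M, C) -> (N, M)."""
    # L2-normalize each row so the dot product equals cosine similarity.
    return F.normalize(inputs, dim=1) @ F.normalize(prototype, dim=1).t()


queries = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
prototype = torch.tensor([[3.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sim = cosine_matching(queries, prototype)
print(sim.shape)  # torch.Size([2, 3])
```

Normalizing first makes the score independent of feature magnitude, so the prototype vector [3, 0] matches the query [1, 0] with similarity 1.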

predict(inputs, data_samples=None, **kwargs)[source]

Predict results from the extracted features.

Parameters:
  • inputs (tuple) – The features extracted from the backbone.

  • data_samples (List[BaseDataElement], optional) – The annotation data of every sample. Defaults to None.

  • **kwargs – Other keyword arguments accepted by the predict method of the head.
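After a similarity matrix is computed, a retrieval prediction typically ranks the prototype items for each query. The sketch below shows only that ranking step with torch.topk. Packing results into data samples is omitted, and the topk_retrieval name and its k default are hypothetical.

```python
import torch


def topk_retrieval(similarity: torch.Tensor, k: int = 2):
    """Scores and prototype indices of the k most similar items per query."""
    # Both outputs have shape (N, k), sorted by descending similarity.
    scores, indices = torch.topk(similarity, k=k, dim=1)
    return scores, indices


sim = torch.tensor([[0.1, 0.9, 0.5],
                    [0.8, 0.2, 0.3]])
scores, indices = topk_retrieval(sim, k=2)
print(indices.tolist())  # [[1, 2], [0, 2]]
```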

prepare_prototype()[source]

Preprocess the prototype before prediction.
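Because extracting features for a whole database is expensive, a retriever may prepare the prototype once and reuse it across predictions. This sketch assumes a lazy, flag-guarded preparation on first use. The LazyPrototype class, the build_fn callable and the prototype_inited flag are hypothetical names for illustration.

```python
import torch


class LazyPrototype:
    """Hypothetical cache that prepares prototype features once, on first use."""

    def __init__(self, build_fn):
        self.build_fn = build_fn  # callable returning an (M, C) tensor
        self.prototype_vecs = None
        self.prototype_inited = False

    def prepare_prototype(self) -> None:
        self.prototype_vecs = self.build_fn()
        self.prototype_inited = True

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        # Build the database lazily so construction stays cheap.
        if not self.prototype_inited:
            self.prepare_prototype()
        return inputs @ self.prototype_vecs.t()


calls = []


def build() -> torch.Tensor:
    calls.append(1)  # record how often the expensive step runs
    return torch.eye(3)


retriever = LazyPrototype(build)
retriever.predict(torch.ones(2, 3))
retriever.predict(torch.ones(2, 3))
print(len(calls))  # 1
```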
