class mmpretrain.models.retrievers.ImageToImageRetriever(image_encoder, prototype, head=None, pretrained=None, similarity_fn='cosine_similarity', train_cfg=None, data_preprocessor=None, topk=-1, init_cfg=None)[source]

Image-to-image retriever for supervised retrieval tasks.

  • image_encoder (Union[dict, List[dict]]) – Encoder for extracting features.

  • prototype (Union[DataLoader, dict, str, torch.Tensor]) –

    Database to be retrieved. The following four types are supported.

    • DataLoader: The original dataloader serves as the prototype.

    • dict: The configuration to construct Dataloader.

    • str: The path of the saved vector.

    • torch.Tensor: The saved tensor whose dimension should be dim.

  • head (dict, optional) – The head module used to calculate the loss from processed features. See mmpretrain.models.heads. Note that if the head is not set, the loss method cannot be used. Defaults to None.

  • similarity_fn (Union[str, Callable]) – How the similarity is calculated. If similarity_fn is callable, it is used directly as the measure function. If it is a string, the method with that name is used. The larger the calculated value, the greater the similarity. Defaults to “cosine_similarity”.

  • train_cfg (dict, optional) –

    The training setting. The acceptable fields are:

    • augments (List[dict]): The batch augmentation methods to use. More details can be found in mmpretrain.model.utils.augment.

    Defaults to None.

  • data_preprocessor (dict, optional) – The config for preprocessing input data. If None, or if no type is specified, “ClsDataPreprocessor” is used as the type. See ClsDataPreprocessor for more details. Defaults to None.

  • topk (int) – The number of top retrieval results to return. -1 means all results are returned. Defaults to -1.

  • init_cfg (dict, optional) – The config to control the initialization. Defaults to None.
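The parameters above are typically supplied through a config dict. The following is a minimal sketch of such a config, not a shipped one: the module types ('ResNet', 'GlobalAveragePooling', 'ArcFaceClsHead', 'CrossEntropyLoss'), the numeric values, and the prototype path are all assumptions chosen for illustration.

```python
# Hypothetical config sketch; every value below is illustrative.
model = dict(
    type='ImageToImageRetriever',
    # image_encoder accepts a dict or a list of dicts. Here a list is used:
    # an assumed backbone followed by an assumed pooling module.
    image_encoder=[
        dict(type='ResNet', depth=50),
        dict(type='GlobalAveragePooling'),
    ],
    # Optional. Without a head, the loss method cannot be used.
    head=dict(
        type='ArcFaceClsHead',
        num_classes=100,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ),
    # str form of prototype: path of previously saved features
    # (hypothetical path).
    prototype='work_dirs/prototype.pth',
    similarity_fn='cosine_similarity',
    # Keep only the five best matches; -1 would keep all of them.
    topk=5,
)
```

Passing a DataLoader, a dataloader config dict, or a torch.Tensor as prototype would replace the string above.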


Save the features extracted from the prototype to the specified path.


path (str) – Path to save the features to.


Extract features from the input tensor with shape (N, C, …).


inputs (Tensor) – A batch of inputs. The shape of it should be (num_samples, num_channels, *img_shape).


The output of encoder.

Return type:


forward(inputs, data_samples=None, mode='tensor')[source]

The unified entry for a forward process in both training and test.

The method should accept three modes: “tensor”, “predict” and “loss”:

  • “tensor”: Forward the whole network and return tensor without any post-processing, same as a common nn.Module.

  • “predict”: Forward and return the predictions, which are fully processed to a list of DataSample.

  • “loss”: Forward and return a dict of losses according to the given inputs and data samples.

Note that this method handles neither back propagation nor optimizer updating; both are done in train_step().

  • inputs (torch.Tensor, tuple) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. It’s required if mode="loss". Defaults to None.

  • mode (str) – The kind of value to return: “tensor”, “predict” or “loss”. Defaults to ‘tensor’.


The return type depends on mode.
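The three modes can be pictured as a simple dispatch on the mode string. The sketch below uses a toy stand-in class (ToyRetriever and its placeholder method bodies are hypothetical, not the mmpretrain implementation) and plain Python lists in place of tensors, so it only illustrates the control flow described above.

```python
class ToyRetriever:
    """Toy stand-in that mimics the three forward modes (hypothetical)."""

    def extract_feat(self, inputs):
        # Placeholder "encoder": returns the inputs unchanged.
        return inputs

    def loss(self, inputs, data_samples):
        # Placeholder: a real model would return computed loss components.
        return {'loss': 0.0}

    def predict(self, inputs, data_samples=None):
        # Placeholder: one result record per input sample.
        return [{'input': x} for x in inputs]

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode == 'tensor':
            return self.extract_feat(inputs)
        if mode == 'loss':
            if data_samples is None:
                raise ValueError('data_samples is required when mode="loss"')
            return self.loss(inputs, data_samples)
        if mode == 'predict':
            return self.predict(inputs, data_samples)
        raise RuntimeError(f'Invalid mode "{mode}".')


retriever = ToyRetriever()
batch = [[0.1, 0.2], [0.3, 0.4]]

tensor_out = retriever.forward(batch)                    # raw features
pred_out = retriever.forward(batch, mode='predict')      # list of results
loss_out = retriever.forward(batch, [{}, {}], mode='loss')  # dict of losses
```

Nothing here performs back propagation or an optimizer step, which matches the note that those happen in train_step().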

loss(inputs, data_samples)[source]

Calculate losses from a batch of inputs and data samples.

  • inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample]) – The annotation data of every sample.


A dictionary of loss components.

Return type:

dict[str, Tensor]


Compare the inputs with the prototype and calculate the similarity.


inputs (torch.Tensor) – The input tensor with shape (N, C).


A dictionary of scores and predicted labels based on the similarity function.

Return type:


predict(inputs, data_samples=None, **kwargs)[source]

Predict results from the extracted features.

  • inputs (tuple) – The features extracted from the backbone.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.

  • **kwargs – Other keyword arguments accepted by the predict method of head.


The raw data_samples with the predicted results.

Return type:



Used in meta testing. This function is called before meta testing and obtains the prototype vectors according to the type of the prototype:

  • torch.Tensor: The tensor is used directly as the prototype vectors.

  • str: The path of the saved features. The file is loaded and parsed to generate the prototype feature vector set.

  • Dataloader or config: The feature vectors are extracted and saved according to the dataloader.

property similarity_fn

Returns a function that calculates the similarity.
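To make the "larger value means more similar" convention and the topk parameter concrete, here is a minimal sketch of a cosine similarity measure followed by top-k ranking. It assumes the callable receives the query features and the prototype features and returns an N x M score matrix; plain Python lists stand in for tensors so the example runs without dependencies, and the names cosine_similarity and rank are illustrative rather than part of the library.

```python
import math


def cosine_similarity(queries, prototypes):
    """Return an N x M score matrix; a larger score means more similar."""
    def normalize(vec):
        norm = math.sqrt(sum(x * x for x in vec)) or 1.0  # avoid divide by zero
        return [x / norm for x in vec]

    q = [normalize(v) for v in queries]
    p = [normalize(v) for v in prototypes]
    return [[sum(a * b for a, b in zip(qi, pj)) for pj in p] for qi in q]


def rank(scores, topk=-1):
    """Sort prototype indices by descending score; topk=-1 keeps all."""
    ranked = []
    for row in scores:
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranked.append(order if topk == -1 else order[:topk])
    return ranked


# Three prototype vectors and one query, each with feature dimension 2.
prototypes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [[2.0, 0.1]]

scores = cosine_similarity(queries, prototypes)
top2 = rank(scores, topk=2)   # best two prototype indices per query
all_ranked = rank(scores)     # topk=-1: every prototype, best first
```

The query points almost along the first axis, so prototype 0 ranks first, the diagonal prototype 2 second, and prototype 1 last.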
