- class mmpretrain.models.retrievers.ImageToImageRetriever(image_encoder, prototype, head=None, pretrained=None, similarity_fn='cosine_similarity', train_cfg=None, data_preprocessor=None, topk=-1, init_cfg=None)
Image-to-image retriever for supervised retrieval tasks.
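For orientation, a model config for this class might look like the sketch below. It assumes the usual MMEngine-style dict config and that image_encoder accepts a list of module configs applied in sequence; the backbone choice, the prototype path `work_dirs/prototype_feats.pth` and `topk=5` are hypothetical values, not recommendations.

```python
# Hypothetical config sketch: backbone, prototype path and topk are illustrative only.
model = dict(
    type='ImageToImageRetriever',
    # Assumed: a list of module configs that together form the image encoder.
    image_encoder=[
        dict(type='ResNet', depth=50),
        dict(type='GlobalAveragePooling'),
    ],
    # str form of prototype: a path to previously saved feature vectors (hypothetical path).
    prototype='work_dirs/prototype_feats.pth',
    similarity_fn='cosine_similarity',  # the documented default
    topk=5,  # keep the 5 most similar prototype entries; -1 would keep all
)
```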
prototype (Union[DataLoader, dict, str, torch.Tensor]) – The database to be retrieved. The following four types are supported:
DataLoader: The original dataloader serves as the prototype.
dict: The configuration used to construct a DataLoader.
str: The path to previously saved feature vectors.
torch.Tensor: A saved feature tensor whose dimension should be dim.
similarity_fn (Union[str, Callable]) – How the similarity is calculated. If similarity_fn is callable, it is used directly as the measure function; if it is a string, the corresponding built-in method is used. A larger value means greater similarity. Defaults to “cosine_similarity”.
train_cfg (dict, optional) – The training setting. The acceptable fields are:
augments (List[dict]): The batch augmentation methods to use. More details can be found in
Defaults to None.
data_preprocessor (dict, optional) – The config for preprocessing input data. If None or no type is specified, “ClsDataPreprocessor” is used as the type. See ClsDataPreprocessor for more details. Defaults to None.
topk (int) – The number of top retrieval results to return; -1 means return all. Defaults to -1.
init_cfg (dict, optional) – The config to control the initialization. Defaults to None.
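The string-or-callable rule for similarity_fn can be illustrated with a small pure-Python sketch. The names `resolve_similarity_fn` and `_BUILTIN_SIMILARITIES` are hypothetical, the sketch compares one pair of vectors at a time rather than whole batches of tensors, and it is not the library's implementation.

```python
import math
from typing import Callable, Sequence, Union

SimFn = Callable[[Sequence[float], Sequence[float]], float]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between a and b; a larger value means more similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# Hypothetical registry of built-in measures, keyed by the string a user may pass.
_BUILTIN_SIMILARITIES = {"cosine_similarity": cosine_similarity}


def resolve_similarity_fn(similarity_fn: Union[str, SimFn]) -> SimFn:
    """Return a callable measure: callables pass through, strings are looked up."""
    if callable(similarity_fn):
        # A user-supplied callable is used directly as the measure function.
        return similarity_fn
    try:
        return _BUILTIN_SIMILARITIES[similarity_fn]
    except KeyError:
        raise ValueError(f"unknown similarity_fn: {similarity_fn!r}") from None
```

Passing a callable lets callers plug in any measure (for example a plain dot product) as long as larger values mean "more similar".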
Save the features extracted from the prototype to a specific path.
path (str) – Path to save the features.
Extract features from the input tensor with shape (N, C, …).
inputs (Tensor) – A batch of inputs. Its shape should be (num_samples, num_channels, *img_shape).
Returns: The output of the encoder.
- forward(inputs, data_samples=None, mode='tensor')
The unified entry for a forward process in both training and test.
The method should accept three modes: “tensor”, “predict” and “loss”:
“tensor”: Forward the whole network and return tensor without any post-processing, same as a common nn.Module.
“predict”: Forward and return the predictions, which are fully processed to a list of
“loss”: Forward and return a dict of losses according to the given inputs and data samples.
Note that this method handles neither back propagation nor optimizer updating, which are done in the
The return type depends on mode:
mode="tensor", return a tensor.
mode="predict", return a list of
mode="loss", return a dict of tensor.
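The mode dispatch described above can be sketched with a toy class. `ToyRetriever` and its placeholder return values are hypothetical; only the three mode names and the kind of value each returns follow the documentation. Raising ValueError on an unknown mode is this sketch's choice.

```python
from typing import Any, Dict, List, Sequence


class ToyRetriever:
    """Hypothetical stand-in mirroring the three documented forward modes."""

    def extract_feat(self, inputs: Sequence[Sequence[float]]) -> List[List[float]]:
        # Placeholder "encoder": copies each input unchanged.
        return [list(x) for x in inputs]

    def predict(self, inputs, data_samples=None) -> List[Dict[str, Any]]:
        # Fully processed output: one result record per input sample.
        return [{"features": f} for f in self.extract_feat(inputs)]

    def loss(self, inputs, data_samples) -> Dict[str, float]:
        # Placeholder loss dict; a real model would compute actual loss values.
        return {"loss": 0.0}

    def forward(self, inputs, data_samples=None, mode: str = "tensor"):
        if mode == "tensor":
            # Raw network output, no post-processing.
            return self.extract_feat(inputs)
        if mode == "predict":
            return self.predict(inputs, data_samples)
        if mode == "loss":
            return self.loss(inputs, data_samples)
        raise ValueError(f"Invalid mode {mode!r}")
```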
- loss(inputs, data_samples)
Calculate losses from a batch of inputs and data samples.
Compare the input features with the prototype and calculate the similarity.
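One way to picture the matching step, together with the topk cut-off, is a sketch in which every query feature is scored against every prototype vector and prototype indices are ranked by descending similarity. The function name, the `score`/`pred_label` keys and the plain-list representation are assumptions for illustration, not the mmpretrain internals; rejecting topk values other than -1 or a positive integer is this sketch's choice.

```python
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]


def match_against_prototype(
    queries: Sequence[Vector],
    prototype: Sequence[Vector],
    sim_fn: Callable[[Vector, Vector], float],
    topk: int = -1,
) -> Dict[str, List[List]]:
    """Score each query against each prototype vector and rank the prototypes.

    score:      N x M similarity matrix (row i belongs to query i)
    pred_label: per query, prototype indices by descending similarity,
                truncated to topk entries unless topk == -1
    """
    if topk != -1 and topk < 1:
        raise ValueError("topk must be -1 (all) or a positive integer")
    score = [[sim_fn(q, p) for p in prototype] for q in queries]
    pred_label = []
    for row in score:
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        pred_label.append(ranked if topk == -1 else ranked[:topk])
    return {"score": score, "pred_label": pred_label}
```

With a dot-product measure, the query `[1.0, 0.0]` against prototypes `[[0.0, 1.0], [2.0, 0.0], [1.0, 0.0]]` scores `[0.0, 2.0, 1.0]`, so the ranking is `[1, 2, 0]`, and `topk=2` keeps `[1, 2]`.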
- predict(inputs, data_samples=None, **kwargs)
Predict results from the extracted features.
- Returns: the raw data_samples with the predicted results
Used in meta testing. This method is called before meta testing starts and obtains the prototype vectors from the prototype, depending on its type:
- torch.Tensor: the tensor is used directly as the prototype vectors.
- str: the path to previously extracted features; the saved data is parsed to generate the prototype feature vector set.
- DataLoader or config: extract and save the feature vectors using the dataloader.
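The type-based dispatch in this list can be sketched in pure Python. `Features` stands in for a torch.Tensor, any iterable of batches stands in for a DataLoader, `build_loader` is a hypothetical hook for the config case, and JSON is only a stand-in storage format; the save step after extraction is omitted.

```python
import json
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Sequence, Union

Vector = List[float]
Batch = Sequence[Vector]


@dataclass
class Features:
    """Stand-in for a precomputed torch.Tensor of prototype features."""
    vectors: List[Vector]


def prepare_prototype_sketch(
    prototype: Union[Features, str, dict, Iterable[Batch]],
    encode: Callable[[Batch], List[Vector]],
    build_loader: Optional[Callable[[dict], Iterable[Batch]]] = None,
) -> List[Vector]:
    """Hypothetical dispatch over the four prototype kinds; not the mmpretrain code."""
    if isinstance(prototype, Features):
        # Tensor case: the given features are used directly as the prototype.
        return prototype.vectors
    if isinstance(prototype, str):
        # Path case: load previously saved vectors (JSON is only a stand-in format).
        with open(prototype) as f:
            return json.load(f)
    if isinstance(prototype, dict):
        # Config case: build a loader first, then extract features from it below.
        if build_loader is None:
            raise ValueError("a dict prototype needs build_loader")
        prototype = build_loader(prototype)
    # Loader case: encode every batch and concatenate the resulting vectors.
    vectors: List[Vector] = []
    for batch in prototype:
        vectors.extend(encode(batch))
    return vectors
```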
- property similarity_fn
Returns a function that calculates the similarity.