ImageToImageRetriever¶
- class mmpretrain.models.retrievers.ImageToImageRetriever(image_encoder, prototype, head=None, pretrained=None, similarity_fn='cosine_similarity', train_cfg=None, data_preprocessor=None, topk=-1, init_cfg=None)[source]¶
Image-to-image retriever for supervised retrieval tasks.
- Parameters:
image_encoder (Union[dict, List[dict]]) – Encoder for extracting features.
prototype (Union[DataLoader, dict, str, torch.Tensor]) –
Database to be retrieved. The following four types are supported.
DataLoader: The original dataloader serves as the prototype.
dict: The configuration to construct a DataLoader.
str: The path of the saved feature vectors.
torch.Tensor: The saved feature tensor; its feature dimension should match the encoder output.
head (dict, optional) – The head module to calculate loss from the processed features. See mmpretrain.models.heads. Notice that if the head is not set, the loss method cannot be used. Defaults to None.
similarity_fn (Union[str, Callable]) – The way the similarity is calculated. If it is callable, it is used directly as the measure function; if it is a string, the corresponding built-in method is used. The larger the calculated value, the greater the similarity. Defaults to “cosine_similarity”.
train_cfg (dict, optional) – The training setting. The acceptable fields are:
augments (List[dict]): The batch augmentation methods to use. More details can be found in mmpretrain.model.utils.augment.
Defaults to None.
data_preprocessor (dict, optional) – The config for preprocessing input data. If None or no type is specified, “ClsDataPreprocessor” is used. See ClsDataPreprocessor for more details. Defaults to None.
topk (int) – Return the top-k retrieval results. -1 means return all. Defaults to -1.
init_cfg (dict, optional) – The config to control the initialization. Defaults to None.
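A minimal construction sketch. The backbone, neck, gallery size, and feature dimension below are illustrative assumptions, not part of this API:

```python
import torch
from mmpretrain.models.retrievers import ImageToImageRetriever

# Hypothetical gallery: 100 pre-extracted feature vectors whose dimension
# matches the encoder output (512 for ResNet-18 + global average pooling).
gallery = torch.randn(100, 512)

retriever = ImageToImageRetriever(
    # image_encoder accepts a single config dict or a list of config dicts.
    image_encoder=[
        dict(type='ResNet', depth=18),
        dict(type='GlobalAveragePooling'),
    ],
    prototype=gallery,                  # alternatively a DataLoader, a dict config, or a file path
    similarity_fn='cosine_similarity',
    topk=5,                             # keep only the 5 most similar gallery entries
)
```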
- dump_prototype(path)[source]¶
Save the features extracted from the prototype to the specified path.
- Parameters:
path (str) – Path to save the features.
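A hedged usage sketch: prepare the gallery features once and cache them to disk, so later runs can pass the saved file back in as the prototype (the file name is hypothetical):

```python
# `retriever` is built as in the constructor sketch above, with `prototype`
# pointing at the gallery (a tensor, DataLoader, dataloader config, or path).
retriever.prepare_prototype()                    # obtain the gallery feature vectors
retriever.dump_prototype('gallery_proto.pth')    # save them to disk

# A later run can reuse the cached features by constructing the retriever
# with prototype='gallery_proto.pth'.
```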
- extract_feat(inputs)[source]¶
Extract features from the input tensor with shape (N, C, …).
- Parameters:
inputs (Tensor) – A batch of inputs with shape (num_samples, num_channels, *img_shape).
- Returns:
The output of encoder.
- Return type:
Tensor
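For example, with the hypothetical encoder from the constructor sketch above, a batch of four 224×224 RGB images is encoded as follows:

```python
import torch

# `retriever` built as in the constructor sketch above.
images = torch.randn(4, 3, 224, 224)    # (num_samples, num_channels, *img_shape)
feats = retriever.extract_feat(images)  # encoder output; its exact shape depends on the encoder
```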
- forward(inputs, data_samples=None, mode='tensor')[source]¶
The unified entry for a forward process in both training and test.
The method should accept three modes: “tensor”, “predict” and “loss”:
“tensor”: Forward the whole network and return tensor without any post-processing, same as a common nn.Module.
“predict”: Forward and return the predictions, which are fully processed to a list of DataSample.
“loss”: Forward and return a dict of losses according to the given inputs and data samples.
Note that this method doesn’t handle back propagation or optimizer updating; those are done in train_step().
- Parameters:
inputs (torch.Tensor, tuple) – The input tensor with shape (N, C, …) in general.
data_samples (List[DataSample], optional) – The annotation data of every sample. It’s required if mode="loss". Defaults to None.
mode (str) – The kind of value to return. Defaults to ‘tensor’.
- Returns:
The return type depends on mode.
If mode="tensor", return a tensor.
If mode="predict", return a list of mmpretrain.structures.DataSample.
If mode="loss", return a dict of tensors.
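A sketch of the three modes; the query tensor and data samples are illustrative, and mode='loss' additionally requires a retriever configured with a head (the constructor sketch above leaves head as None):

```python
import torch
from mmpretrain.structures import DataSample

queries = torch.randn(2, 3, 224, 224)

# 'tensor': raw encoder output, no post-processing.
feats = retriever(queries, mode='tensor')

# 'predict': retrieval results as a list of DataSample objects.
results = retriever(queries, mode='predict')

# 'loss': needs annotated data samples and a configured head.
samples = [DataSample().set_gt_label(i) for i in range(2)]
losses = retriever(queries, samples, mode='loss')
```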
- loss(inputs, data_samples)[source]¶
Calculate losses from a batch of inputs and data samples.
- Parameters:
inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.
data_samples (List[DataSample]) – The annotation data of every sample.
- Returns:
A dictionary of loss components.
- Return type:
dict
- matching(inputs)[source]¶
Compare the inputs with the prototype and calculate the similarity.
- Parameters:
inputs (torch.Tensor) – The input tensor with shape (N, C).
- Returns:
A dictionary of similarity scores and prediction labels computed by the similarity function.
- Return type:
dict
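A small sketch of matching pre-extracted query features against the gallery; the feature dimension follows the constructor sketch above, and the contents of the returned dict are described only loosely, as in this documentation:

```python
import torch

# Make sure the prototype has been prepared before matching.
retriever.prepare_prototype()

# Two query feature vectors with the same dimension as the gallery (512 here).
query_feats = torch.randn(2, 512)
result = retriever.matching(query_feats)
# `result` holds the similarity scores against the gallery entries and the
# corresponding prediction labels, ready to be packed into DataSamples.
```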
- predict(inputs, data_samples=None, **kwargs)[source]¶
Predict results from the extracted features.
- Parameters:
inputs (tuple) – The features extracted from the backbone.
data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.
**kwargs – Other keyword arguments accepted by the predict method of the head.
- Returns:
- The raw data_samples with the predicted results.
- Return type:
List[DataSample]
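A hedged sketch of reading the returned data samples; the pred_label/pred_score field names follow the usual mmpretrain convention and are assumptions here:

```python
import torch

queries = torch.randn(2, 3, 224, 224)
results = retriever(queries, mode='predict')   # reaches `predict` internally

for sample in results:
    # Indices of the most similar gallery entries and their similarity scores.
    print(sample.pred_label, sample.pred_score)
```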
- prepare_prototype()[source]¶
Prepare the prototype vectors. This function is called before testing to obtain the feature vectors from the configured prototype, depending on its type:
- torch.Tensor: The tensor itself is used as the prototype vectors.
- str: The path of a saved feature file, which is parsed to generate the prototype feature vector set.
- DataLoader or config: The feature vectors are extracted and saved according to the dataloader.
- property similarity_fn¶
Returns a function that calculates the similarity.
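Besides the built-in “cosine_similarity”, a callable can be supplied. A minimal sketch of a dot-product similarity, assuming the callable receives the query features and the prototype matrix and returns a pairwise score matrix in which larger values mean more similar:

```python
import torch
from mmpretrain.models.retrievers import ImageToImageRetriever

def dot_product_similarity(query, gallery):
    # query: (N, D) query features; gallery: (M, D) prototype features.
    return query @ gallery.t()

retriever = ImageToImageRetriever(
    image_encoder=[dict(type='ResNet', depth=18),
                   dict(type='GlobalAveragePooling')],
    prototype=torch.randn(100, 512),    # hypothetical pre-extracted gallery features
    similarity_fn=dot_product_similarity,
)
```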