ImageToImageRetriever

class mmpretrain.models.retrievers.ImageToImageRetriever(image_encoder, prototype, head=None, pretrained=None, similarity_fn='cosine_similarity', train_cfg=None, data_preprocessor=None, topk=-1, init_cfg=None)[source]

Image-to-image retriever for supervised retrieval tasks.

Parameters:
  • image_encoder (Union[dict, List[dict]]) – Encoder for extracting features.

  • prototype (Union[DataLoader, dict, str, torch.Tensor]) –

    The database to retrieve from. The following four types are supported:

    • DataLoader: The original dataloader serves as the prototype.

    • dict: The configuration used to construct a DataLoader.

    • str: The path to the saved feature vectors.

    • torch.Tensor: The saved tensor whose dimension should be dim.

  • head (dict, optional) – The head module to calculate loss from processed features. See mmpretrain.models.heads. Note that if the head is not set, the loss method cannot be used. Defaults to None.

  • similarity_fn (Union[str, Callable]) – How the similarity is calculated. If similarity_fn is callable, it is used directly as the measure function. If it is a string, the corresponding method is used. The larger the calculated value, the greater the similarity. Defaults to “cosine_similarity”.

  • train_cfg (dict, optional) –

    The training setting. The acceptable fields are:

    • augments (List[dict]): The batch augmentation methods to use. More details can be found in mmpretrain.model.utils.augment.

    Defaults to None.

  • data_preprocessor (dict, optional) – The config for preprocessing input data. If None, or if no type is specified, “ClsDataPreprocessor” is used as the type. See ClsDataPreprocessor for more details. Defaults to None.

  • topk (int) – Return the top-k retrieval results. -1 means return all. Defaults to -1.

  • init_cfg (dict, optional) – The config to control the initialization. Defaults to None.
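
As a rough illustration of how these parameters fit together, the following config-style sketch builds the arguments as plain Python dicts. Every concrete value in it (the ResNet encoder, the pooling neck, LinearClsHead, the InShop gallery settings, the paths, and topk=5) is an assumption chosen for illustration and does not come from this page; substitute the components your project actually uses.

```python
# Config-style sketch only; every concrete value below is an illustrative assumption.

# Hypothetical dataloader config for the retrieval database (the "prototype").
gallery_dataloader = dict(
    batch_size=32,
    num_workers=4,
    dataset=dict(type='InShop', data_root='data/inshop', split='gallery'),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

model = dict(
    type='ImageToImageRetriever',
    # A list of dicts: a backbone followed by a pooling layer (assumed choice).
    image_encoder=[
        dict(type='ResNet', depth=50),
        dict(type='GlobalAveragePooling'),
    ],
    # Optional; without a head the loss method cannot be used.
    head=dict(type='LinearClsHead', num_classes=100, in_channels=2048),
    # Here the dict form is used; a DataLoader, a path, or a tensor are the other options.
    prototype=gallery_dataloader,
    similarity_fn='cosine_similarity',
    topk=5,  # keep only the five best matches; -1 would keep all
)
```

Passing the dict form of prototype defers feature extraction until the prototype is prepared, whereas a saved path or tensor reuses features that were extracted earlier.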

dump_prototype(path)[source]

Save the features extracted from the prototype to the specified path.

Parameters:

path (str) – Path to save the features to.

extract_feat(inputs)[source]

Extract features from the input tensor with shape (N, C, …).

Parameters:

inputs (Tensor) – A batch of inputs. Its shape should be (num_samples, num_channels, *img_shape).

Returns:

The output of the encoder.

Return type:

Tensor

forward(inputs, data_samples=None, mode='tensor')[source]

The unified entry for a forward process in both training and test.

The method should accept three modes: “tensor”, “predict” and “loss”:

  • “tensor”: Forward the whole network and return tensor without any post-processing, same as a common nn.Module.

  • “predict”: Forward and return the predictions, which are fully processed to a list of DataSample.

  • “loss”: Forward and return a dict of losses according to the given inputs and data samples.

Note that this method handles neither back propagation nor optimizer updating, which are done in train_step().

Parameters:
  • inputs (torch.Tensor, tuple) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. It is required if mode="loss". Defaults to None.

  • mode (str) – The kind of value to return. Defaults to ‘tensor’.

Returns:

The return type depends on mode.
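
The three modes can be pictured as a simple dispatch. The sketch below is a minimal, dependency-free stand-in, not the library's implementation: ToyRetriever and its placeholder method bodies are hypothetical, and only the routing on mode mirrors the description above.

```python
class ToyRetriever:
    """Hypothetical stand-in that mimics only the mode dispatch."""

    def extract_feat(self, inputs):
        # A real model would run the image encoder here.
        return inputs

    def loss(self, inputs, data_samples):
        # Placeholder value; a real model would delegate to its head.
        return {'loss': 0.0}

    def predict(self, inputs, data_samples=None):
        # Placeholder; a real model would attach retrieval results to each sample.
        return list(data_samples or [])

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode == 'tensor':
            return self.extract_feat(inputs)
        if mode == 'loss':
            if data_samples is None:
                raise ValueError('data_samples is required when mode="loss"')
            return self.loss(inputs, data_samples)
        if mode == 'predict':
            return self.predict(inputs, data_samples)
        raise RuntimeError(f'Invalid mode "{mode}".')


retriever = ToyRetriever()
features = retriever.forward([0.1, 0.2])                        # raw features
losses = retriever.forward([0.1, 0.2], ['sample'], mode='loss')  # dict of losses
preds = retriever.forward([0.1, 0.2], ['sample'], mode='predict')
```

Keeping the dispatch in one method lets the training loop and the test loop call the same entry point and differ only in the mode argument.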

loss(inputs, data_samples)[source]

Calculate losses from a batch of inputs and data samples.

Parameters:
  • inputs (torch.Tensor) – The input tensor with shape (N, C, …) in general.

  • data_samples (List[DataSample]) – The annotation data of every sample.

Returns:

A dictionary of loss components.

Return type:

dict[str, Tensor]

matching(inputs)[source]

Compare the inputs with the prototype and calculate the similarity.

Parameters:

inputs (torch.Tensor) – The input tensor with shape (N, C).

Returns:

A dictionary of scores and predicted labels based on the similarity function.

Return type:

dict
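
To make the matching step concrete, here is a minimal sketch of cosine-similarity ranking with a top-k cut-off. It uses plain Python lists instead of torch tensors to stay dependency-free, and the helper names and the result keys ('score', 'pred_label') are assumptions for illustration rather than the library's actual output format.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors; larger means more similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def matching(queries, prototype, topk=-1):
    """Rank prototype entries for each query, best match first.

    queries: N feature vectors of length C; prototype: M feature vectors of length C.
    """
    all_scores, all_labels = [], []
    for query in queries:
        scores = [cosine_similarity(query, entry) for entry in prototype]
        # Indices of prototype entries sorted by descending similarity.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if topk > 0:  # topk=-1 keeps every entry
            order = order[:topk]
        all_scores.append([scores[i] for i in order])
        all_labels.append(order)
    return {'score': all_scores, 'pred_label': all_labels}


prototype = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
result = matching([[1.0, 0.1]], prototype, topk=2)
print(result['pred_label'])  # indices of the two closest prototype entries
```

A custom callable passed as similarity_fn would take the place of cosine_similarity here, provided that larger values still mean greater similarity.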

predict(inputs, data_samples=None, **kwargs)[source]

Predict results from the extracted features.

Parameters:
  • inputs (tuple) – The features extracted from the backbone.

  • data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.

  • **kwargs – Other keyword arguments accepted by the predict method of head.

Returns:

The raw data_samples with the predicted results.

Return type:

List[DataSample]

prepare_prototype()[source]

Used in meta testing. This function is called before meta testing to obtain the prototype vectors, depending on the type of prototype:

  • torch.Tensor: The tensor is used directly as the prototype vectors.

  • str: The path to previously extracted features; the file is loaded and parsed to generate the set of prototype feature vectors.

  • DataLoader or config: The feature vectors are extracted and saved according to the dataloader.

property similarity_fn

Returns a function that calculates the similarity.
