TextToImageRetrievalInferencer

class mmpretrain.apis.TextToImageRetrievalInferencer(model, prototype, prototype_cache=None, fast_match=True, prepare_batch_size=8, pretrained=True, device=None, **kwargs)[source]

The inferencer for text to image retrieval.

Parameters:
  • model (BaseModel | str | Config) – A model name, a path to the config file, or a BaseModel object. The available model names can be found by TextToImageRetrievalInferencer.list_models(), or in the Model Zoo Summary.

  • prototype (str | list | dict | DataLoader | BaseDataset) –

    The images to be retrieved. It can be the following types:

    • str: The directory of the images.

    • list: A list of path of the images.

    • dict: A config dict of a prototype dataset.

    • BaseDataset: A prototype dataset.

    • DataLoader: A data loader to load the prototype data.

  • prototype_cache (str, optional) – The path of the generated prototype features. If the file exists, the cached features are loaded directly instead of being re-generated; otherwise, the generated features are saved to this path. Defaults to None.

  • fast_match (bool) – Some algorithms record extra image features for further matching, which may consume large amounts of memory; set this to True to avoid that behavior. Defaults to True.

  • pretrained (bool | str) – Whether to load the pretrained weights, or a path to the checkpoint file. If True, it will try to find a pre-defined weight for the model you specified (only works if the model is a model name). Defaults to True.

  • device (str, optional) – Device to run inference. If None, an available device is used automatically. Defaults to None.

  • **kwargs – Other keyword arguments to initialize the model (only works if the model is a model name).

Examples

>>> from mmpretrain import TextToImageRetrievalInferencer
>>> inferencer = TextToImageRetrievalInferencer(
...     'blip-base_3rdparty_retrieval',
...     prototype='./demo/',
...     prototype_cache='t2i_retri.pth')
>>> inferencer('A cat and a dog.')[0]
{'match_score': tensor(0.3855, device='cuda:0'),
 'sample_idx': 1,
 'sample': {'img_path': './demo/cat-dog.png'}}
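The prototype can also be an explicit list of image paths instead of a directory (the paths below are illustrative):

>>> inferencer = TextToImageRetrievalInferencer(
...     'blip-base_3rdparty_retrieval',
...     prototype=['./demo/cat-dog.png', './demo/bird.JPEG'],
...     prototype_cache='t2i_retri.pth')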
__call__(inputs, return_datasamples=False, batch_size=1, **kwargs)[source]

Call the inferencer.

Parameters:
  • inputs (str | list) – The text query, or a list of text queries.

  • return_datasamples (bool) – Whether to return results as DataSample. Defaults to False.

  • batch_size (int) – Batch size. Defaults to 1.

  • resize (int, optional) – Resize the long edge of the image to the specified length before visualization. Defaults to None.

  • draw_score (bool) – Whether to draw the match scores. Defaults to True.

  • show (bool) – Whether to display the visualization result in a window. Defaults to False.

  • wait_time (float) – The display time (s). Defaults to 0, which means “forever”.

  • show_dir (str, optional) – If not None, save the visualization results in the specified directory. Defaults to None.

Returns:

The inference results.

Return type:

list
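Building on the constructor example above, a sketch of calling the inferencer with several text queries and the visualization options (the show_dir path is illustrative, and running it requires the pretrained weights):

>>> results = inferencer(
...     ['A cat and a dog.', 'A dog sleeping.'],
...     batch_size=2,
...     show_dir='./t2i_visualize/')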

forward(data, **kwargs)[source]

Feed the inputs to the model.

static list_models(pattern=None)[source]

List all available model names.

Parameters:

pattern (str | None) – A wildcard pattern to match model names.

Returns:

A list of model names.

Return type:

List[str]
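The pattern is a shell-style wildcard. As a sketch of the matching semantics (assuming fnmatch-style globbing; the model names below are illustrative, not a real query):

```python
from fnmatch import fnmatch

# Illustrative model names; the real list comes from
# TextToImageRetrievalInferencer.list_models().
names = [
    'blip-base_3rdparty_retrieval',
    'blip2-opt2.7b_3rdparty-zero-shot_caption',
]

# A wildcard pattern such as 'blip-*' keeps only the matching names.
matches = [n for n in names if fnmatch(n, 'blip-*')]
print(matches)  # ['blip-base_3rdparty_retrieval']
```

In practice, a call like list_models('blip-*') would narrow the listing to the BLIP retrieval variants.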
