OFA
- class mmpretrain.models.multimodal.OFA(encoder_cfg, decoder_cfg, vocab_size, embedding_dim, tokenizer, task, prompt=None, ans2label=None, generation_cfg={}, data_preprocessor=None, init_cfg=None)
The OFA model for multiple tasks.
- Parameters:
encoder_cfg (dict) – The config of the encoder; accepts the keyword arguments of OFAEncoder.
decoder_cfg (dict) – The config of the decoder; accepts the keyword arguments of OFADecoder.
vocab_size (int) – The size of the vocabulary.
embedding_dim (int) – The embedding dimension of both the encoder and the decoder.
tokenizer (dict | PreTrainedTokenizer) – The tokenizer used to encode the text.
task (str) – The task name. Supported tasks are “caption”, “vqa” and “refcoco”.
prompt (str, optional) –
The prompt template for the following tasks. If None, the default prompt is used:
caption: ‘ what does the image describe?’
refcoco: ‘ which region does the text " {} " describe?’ ({} is filled with the referring text)
Defaults to None.
ans2label (str | Sequence | None) – The answer-to-label mapping for the vqa task. If a string, it should be the path to a pickle or JSON file. When provided, it constrains the output answers. Defaults to None, which means no constraint.
generation_cfg (dict) – The extra generation config; accepts the keyword arguments of GenerationConfig. Defaults to an empty dict.
data_preprocessor (dict, optional) – The config for preprocessing input data. If None or no type is specified, “MultiModalDataPreprocessor” is used as the type. See MultiModalDataPreprocessor for more details. Defaults to None.
init_cfg (dict, optional) – The initialization config. Defaults to None.
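The argument defaults described above (default prompts per task, loading `ans2label` from a file, and falling back to “MultiModalDataPreprocessor”) can be illustrated with a small, self-contained sketch. It is not the mmpretrain implementation: the helpers `resolve_prompt`, `load_ans2label` and `resolve_preprocessor` are hypothetical, and the sketch assumes that the quotes inside the refcoco prompt are plain `"` characters and that a string `ans2label` is dispatched on its file extension (the real class may use a generic loader instead).

```python
import json
import pickle
from collections.abc import Mapping, Sequence
from typing import Optional, Union

# Default prompts as listed in the docs (assumption: plain ASCII quotes in source).
DEFAULT_PROMPTS = {
    'caption': ' what does the image describe?',
    'refcoco': ' which region does the text " {} " describe?',
}
SUPPORTED_TASKS = ('caption', 'vqa', 'refcoco')


def resolve_prompt(task: str, prompt: Optional[str] = None,
                   text: str = '') -> Optional[str]:
    """Hypothetical helper: use the given template or the task default, then fill {}."""
    if task not in SUPPORTED_TASKS:
        raise ValueError(f'Unsupported task {task!r}; expected one of {SUPPORTED_TASKS}')
    template = prompt if prompt is not None else DEFAULT_PROMPTS.get(task)
    if template is None:
        return None  # no default prompt is documented for this task (e.g. "vqa")
    # str.format ignores unused arguments, so templates without {} are returned as-is.
    return template.format(text)


def load_ans2label(ans2label: Union[str, Sequence, Mapping, None]):
    """Hypothetical helper: load a string path as JSON or pickle; pass others through."""
    if isinstance(ans2label, str):
        if ans2label.endswith('.json'):
            with open(ans2label, 'r', encoding='utf-8') as f:
                return json.load(f)
        if ans2label.endswith(('.pkl', '.pickle')):
            with open(ans2label, 'rb') as f:
                return pickle.load(f)
        raise ValueError(f'Expected a .json or .pkl file, got {ans2label!r}')
    return ans2label  # None means the answers are unconstrained


def resolve_preprocessor(data_preprocessor: Optional[dict] = None) -> dict:
    """Fall back to MultiModalDataPreprocessor when the config is None or has no type."""
    cfg = dict(data_preprocessor or {})
    cfg.setdefault('type', 'MultiModalDataPreprocessor')
    return cfg
```

For example, `resolve_prompt('refcoco', text='a red car')` returns `' which region does the text " a red car " describe?'`, and `resolve_preprocessor(None)` returns `{'type': 'MultiModalDataPreprocessor'}`, while a config that already names a `type` is left unchanged.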
- forward(images, data_samples=None, mode='predict', **kwargs)
The unified entry for a forward process in both training and test. The method accepts the following modes:
“predict”: Forward and return a list of data samples containing the prediction results.
- Parameters:
images (torch.Tensor) – The preprocessed image tensor of shape (N, C, H, W).
data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.
mode (str) – The kind of value to return. Defaults to ‘predict’.
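The mode dispatch and input contract described for `forward` can be sketched with a toy stand-in. `ToyOFA`, its `predict` method and the choice of `RuntimeError` for an unsupported mode are assumptions for illustration; the real method returns `DataSample` objects produced by the encoder and decoder, which this sketch replaces with placeholder dictionaries. It relies only on `images` exposing a `.shape` attribute, as `torch.Tensor` does.

```python
from typing import Any, List, Optional


class ToyOFA:
    """Hypothetical stand-in that mirrors the documented forward() contract."""

    def forward(self, images: Any, data_samples: Optional[List[dict]] = None,
                mode: str = 'predict', **kwargs) -> List[dict]:
        # Only "predict" is documented; reject any other mode explicitly.
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        raise RuntimeError(f'Invalid mode "{mode}".')

    def predict(self, images: Any, data_samples: Optional[List[dict]] = None,
                **kwargs) -> List[dict]:
        # images is expected to be a preprocessed batch of shape (N, C, H, W).
        if len(images.shape) != 4:
            raise ValueError(f'Expected (N, C, H, W), got shape {tuple(images.shape)}')
        num_images = images.shape[0]
        if data_samples is None:
            # No annotations given: start from one empty placeholder per image.
            data_samples = [{} for _ in range(num_images)]
        if len(data_samples) != num_images:
            raise ValueError('Expected one data sample per image')
        # Placeholder prediction; a real model would run the encoder/decoder here.
        return [dict(sample, pred='<generated text>') for sample in data_samples]
```

With a batch such as `torch.zeros(2, 3, 480, 480)`, `ToyOFA().forward(batch)` returns two placeholder results, while any mode string other than `'predict'` raises `RuntimeError`.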