MiniGPT4
- class mmpretrain.models.multimodal.MiniGPT4(vision_encoder, q_former_model, lang_encoder, tokenizer, task='caption', freeze_vit=True, freeze_q_former=True, num_query_token=32, prompt_template={'en': '###Ask: {} ###Answer: ', 'zh': '###问:{} ###答:'}, raw_prompts={}, max_txt_len=32, end_sym='###', generation_cfg={}, data_preprocessor=None, init_cfg=None)
The multi-modality model of MiniGPT-4.
This implementation is modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py
- Parameters:
vision_encoder (dict) – The config for vision encoder.
q_former_model (dict) – The config for Qformer.
lang_encoder (dict) – The config for language model.
tokenizer (dict) – The config for tokenizer.
task (str) – The task to perform, which controls how the text is processed. Defaults to 'caption'.
freeze_vit (bool) – Freeze the training of ViT. Defaults to True.
freeze_q_former (bool) – Freeze the training of Qformer. Defaults to True.
num_query_token (int) – Number of query tokens of Qformer. Defaults to 32.
prompt_template (dict) – Multi-language prompt templates of the model. Defaults to {'en': '###Ask: {} ###Answer: ', 'zh': '###问:{} ###答:'}.
raw_prompts (dict) – Prompts for training. Defaults to dict().
max_txt_len (int) – Maximum token length during tokenization. Defaults to 32.
end_sym (str) – End symbol of the sequence. Defaults to '###'.
generation_cfg (dict) – The config of text generation. Defaults to dict().
data_preprocessor (BaseDataPreprocessor) – Used for pre-processing data sampled by the dataloader into the format accepted by forward(). Defaults to None.
init_cfg (dict) – Initialization config dict. Defaults to None.
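As a rough illustration of the prompt_template parameter, the sketch below fills the default templates with a question. The helper name build_prompt and the use of str.format are assumptions made for this example; the reference above does not state how the model applies the templates internally.

```python
# Default templates copied from the class signature above.
PROMPT_TEMPLATE = {
    'en': '###Ask: {} ###Answer: ',
    'zh': '###问:{} ###答:',
}


def build_prompt(question: str, lang: str = 'en') -> str:
    """Hypothetical helper: fill the template for the given language key."""
    return PROMPT_TEMPLATE[lang].format(question)


print(build_prompt('What is in the image?'))
print(build_prompt('图片里有什么?', lang='zh'))
```

Keeping one template per language key means the same question text can be wrapped differently depending on the language of the input.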
- forward(images, data_samples=None, mode='predict', **kwargs)
The unified entry for a forward process in both training and test. The method accepts the following modes:
"predict": Forward and return a list of data samples that contain the prediction results.
- Parameters:
images (torch.Tensor) – The preprocessed image tensor of shape (N, C, H, W).
data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.
mode (str) – The kind of value to return. Defaults to 'predict'.
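The mode argument can be pictured as a simple dispatch. The following is a minimal sketch using a stand-in class, not the mmpretrain implementation: ToyModel, its predict method, and the choice to raise RuntimeError for any other mode are all assumptions made for illustration. Only the "predict" mode documented above is modelled.

```python
class ToyModel:
    """Stand-in for illustration only; not the mmpretrain class."""

    def predict(self, images, data_samples=None, **kwargs):
        # A real model would run text generation here.
        return [f'prediction for image {i}' for i in range(len(images))]

    def forward(self, images, data_samples=None, mode='predict', **kwargs):
        # Route the call according to the requested mode.
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        # Assumed behaviour: reject modes this sketch does not model.
        raise RuntimeError(f'Invalid mode "{mode}".')


model = ToyModel()
print(model.forward(['img_a', 'img_b']))
```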
- loss(images, data_samples=None)
The forward function in training.
- Parameters:
images (List[torch.Tensor]) – The input images.
data_samples (List[DataSample]) – All elements required during the forward function.
- Returns:
A dictionary of loss components.
- Return type:
Dict[str, torch.Tensor]
- post_process(outputs, data_samples)
Post-process the outputs according to the task.
- Parameters:
outputs (torch.Tensor) – The generated outputs.
data_samples (List[DataSample], optional) – The annotation data of every sample.
- Returns:
A list of data samples.
- Return type:
List[DataSample]
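One plausible post-processing step is to cut decoded text at the end symbol. The sketch below assumes the outputs have already been decoded to strings and that end_sym ('###' by default, per the class parameters) marks where the answer stops. The helper truncate_at_end_sym is hypothetical and is not part of the documented API.

```python
from typing import List


def truncate_at_end_sym(texts: List[str], end_sym: str = '###') -> List[str]:
    """Hypothetical helper: keep only the text before the first end symbol."""
    return [text.split(end_sym)[0].strip() for text in texts]


decoded = ['A cat on a sofa.### extra tokens', 'No marker here']
print(truncate_at_end_sym(decoded))
```

Strings without the end symbol pass through unchanged apart from whitespace stripping.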
- prompt_wrap(img_embeds, atts_img, prompt)
The function to wrap the image and prompt.
Make sure that len(prompt) == img_embeds.shape[0].
- Parameters:
img_embeds (torch.Tensor) – The embedding of the input images.
atts_img (torch.Tensor) – Attention map of the image embeddings.
prompt (List[str]) – The prompt of the batch data.
- Returns:
The embedding and attention map.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
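The wrapping idea can be sketched in plain Python, with nested lists standing in for tensors. Everything here is an assumption made for illustration: the '<ImageHere>' marker, the toy_embed function, and the helper names are hypothetical, and the real method operates on torch.Tensor batches. The sketch only shows image embeddings being placed between the text before and after a marker, plus the documented precondition of one prompt per image.

```python
from typing import List, Tuple

Vector = List[float]
IMAGE_PLACEHOLDER = '<ImageHere>'  # assumed marker, not documented above


def toy_embed(text: str, dim: int = 2) -> List[Vector]:
    """Fake text embedding: one zero vector per whitespace-separated word."""
    return [[0.0] * dim for _ in text.split()]


def wrap_one(img_embeds: List[Vector],
             prompt: str) -> Tuple[List[Vector], List[int]]:
    """Place image embeddings between the text before and after the marker."""
    before, after = prompt.split(IMAGE_PLACEHOLDER)
    wrapped = toy_embed(before) + img_embeds + toy_embed(after)
    attention = [1] * len(wrapped)  # every position is attended to
    return wrapped, attention


def wrap_batch(batch_img_embeds: List[List[Vector]], prompts: List[str]):
    # Mirrors the documented precondition: one prompt per image.
    assert len(prompts) == len(batch_img_embeds)
    return [wrap_one(e, p) for e, p in zip(batch_img_embeds, prompts)]


image = [[1.0, 1.0]] * 3  # three query-token embeddings of width 2
wrapped, attention = wrap_one(image, '###Ask: <ImageHere> describe it ###Answer:')
print(len(wrapped), attention)
```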