MiniGPT4

class mmpretrain.models.multimodal.MiniGPT4(vision_encoder, q_former_model, lang_encoder, tokenizer, task='caption', freeze_vit=True, freeze_q_former=True, num_query_token=32, prompt_template={'en': '###Ask: {} ###Answer: ', 'zh': '###问:{} ###答:'}, raw_prompts={}, max_txt_len=32, end_sym='###', generation_cfg={}, data_preprocessor=None, init_cfg=None)[source]

The multi-modality model of MiniGPT-4.

The implementation is modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py

Parameters:
  • vision_encoder (dict) – The config for the vision encoder.

  • q_former_model (dict) – The config for the Qformer.

  • lang_encoder (dict) – The config for the language model.

  • tokenizer (dict) – The config for the tokenizer.

  • task (str) – The task, which controls how the text is processed. Defaults to ‘caption’.

  • freeze_vit (bool) – Freeze the training of ViT. Defaults to True.

  • freeze_q_former (bool) – Freeze the training of Qformer. Defaults to True.

  • num_query_token (int) – Number of query tokens of Qformer. Defaults to 32.

  • prompt_template (dict) – Multi-language prompt templates of the model. Defaults to {'en': '###Ask: {} ###Answer: ', 'zh': '###问:{} ###答:'}.

  • raw_prompts (dict) – Prompts for training. Defaults to dict().

  • max_txt_len (int) – Max token length while doing tokenization. Defaults to 32.

  • end_sym (str) – End symbol of the sequence. Defaults to ‘###’.

  • generation_cfg (dict) – The config of text generation. Defaults to dict().

  • data_preprocessor (BaseDataPreprocessor) – Used to pre-process data sampled from the dataloader into the format accepted by forward(). Defaults to None.

  • init_cfg (dict) – Initialization config dict. Defaults to None.
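
A hedged usage sketch follows: building the registered model through mmpretrain's high-level API and running caption inference. The model name and image path below are assumptions; query list_models to see what your installation actually registers.

    from mmpretrain import inference_model, list_models

    # The registered names are an assumption; list them first to confirm.
    print(list_models(pattern='minigpt'))

    # Run caption inference; the model name and image path are assumed.
    result = inference_model('minigpt-4_vicuna-7b_caption', 'demo/cat-dog.png')
    print(result['pred_caption'])  # assumed output key for task='caption'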

encode_img(images)[source]

Encode the input images.
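
A minimal sketch of the call, assuming model is a built MiniGPT4 instance. The return signature (query-token embeddings plus their attention mask) is inferred from prompt_wrap below and should be treated as an assumption.

    import torch

    images = torch.rand(2, 3, 224, 224)  # (N, C, H, W); the input size is an assumption
    img_embeds, atts_img = model.encode_img(images)
    # img_embeds: query-token embeddings, roughly (N, num_query_token, hidden_dim)
    # atts_img:   attention mask over those tokens, roughly (N, num_query_token)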

forward(images, data_samples=None, mode='predict', **kwargs)[source]

The unified entry for a forward process in both training and testing. The method accepts the following modes:

  • “predict”: Forward and return a list of data samples containing the prediction results.

Parameters:
  • images (torch.Tensor) – The preprocessed image tensor of shape (N, C, H, W).

  • data_samples (List[DataSample], optional) – The annotation data of every sample. Defaults to None.

  • mode (str) – Which kind of value to return. Defaults to ‘predict’.
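
A hedged sketch of the ‘predict’ mode described above, assuming model is a built MiniGPT4 instance and that DataSample comes from mmpretrain.structures.

    import torch
    from mmpretrain.structures import DataSample

    images = torch.rand(1, 3, 224, 224)  # (N, C, H, W)
    data_samples = [DataSample()]        # one annotation object per image
    with torch.no_grad():
        predictions = model(images, data_samples=data_samples, mode='predict')
    # `predictions` is a list of DataSample objects carrying the results.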

loss(images, data_samples=None)[source]

The forward function in training.

Parameters:
  • images (torch.Tensor) – The input images.

  • data_samples (List[DataSample]) – All elements required during the forward function.

Returns:

A dictionary of loss components.

Return type:

Dict[str, torch.Tensor]
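
A training-style sketch under the same assumptions. Summing the returned dict into a single scalar before the backward pass follows the usual mmengine convention; the ground-truth text carried by each DataSample is an assumption here.

    losses = model.loss(images, data_samples=data_samples)
    total_loss = sum(v for v in losses.values())  # Dict[str, torch.Tensor]
    total_loss.backward()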

post_process(outputs, data_samples)[source]

Perform post-processing on the outputs for different tasks.

Parameters:
  • outputs (torch.Tensor) – The generated outputs.

  • data_samples (List[DataSample], optional) – The annotation data of every sample.

Returns:

A list of data samples containing the processed results.

Return type:

List[DataSample]
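
A hedged sketch of where post_process sits in the predict path: it decodes the generated token ids and writes the text back onto the data samples. The generated_ids name and the output field read below (pred_caption for the caption task) are assumptions.

    # `generated_ids` stands in for the token ids produced by the language
    # model's generate step; the name is hypothetical.
    out_samples = model.post_process(generated_ids, data_samples)
    print(out_samples[0].pred_caption)  # assumed field for task='caption'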

prompt_wrap(img_embeds, atts_img, prompt)[source]

Wrap the image embeddings with the prompt.

Make sure that len(prompt) == img_embeds.shape[0].

Parameters:
  • img_embeds (torch.Tensor) – The embedding of the input images.

  • atts_img (torch.Tensor) – Attention map of the image embeddings.

  • prompt (List[str]) – The prompt of the batch data.

Returns:

The embedding and attention map.

Return type:

Tuple[torch.Tensor, torch.Tensor]
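
An illustrative call with one prompt string per image, as required above. The ‘<ImageHere>’ placeholder convention is taken from the upstream MiniGPT-4 repository and is an assumption here.

    prompts = [
        '###Ask: <Img><ImageHere></Img> Describe this image. ###Answer: '
    ] * img_embeds.shape[0]  # len(prompt) must equal the batch size
    wrapped_embeds, wrapped_atts = model.prompt_wrap(img_embeds, atts_img, prompts)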
