Shortcuts

Source code for mmpretrain.models.multimodal.blip.blip_vqa

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmengine.model import BaseModel

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample


@MODELS.register_module()
class BlipVQA(BaseModel):
    """BLIP VQA.

    Args:
        tokenizer (dict): The config for tokenizer.
        vision_backbone (dict): Encoder for extracting image features.
        multimodal_backbone (dict): Backbone for extracting
            multi-modal features. We apply this part as VQA fusion module.
        head (dict): The head module to calculate
            loss from processed features.
        data_preprocessor (Optional[dict]): The config for preprocessing input
            data. If None or no specified type, it will use
            `MultiModalDataPreprocessor` as type.
            See :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Optional[dict]): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer: dict,
                 vision_backbone: dict,
                 multimodal_backbone: dict,
                 head: dict,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        # NOTE: ``setdefault`` writes the default type into the dict that
        # was passed in (kept as in the original implementation).
        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
        data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        self.tokenizer = TOKENIZER.build(tokenizer)
        self.vision_backbone = MODELS.build(vision_backbone)
        self.multimodal_backbone = MODELS.build(multimodal_backbone)
        self.vqa_head = MODELS.build(head)

    @property
    def device(self):
        """Device of the first parameter of this model."""
        return next(self.parameters()).device

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[List[DataSample]] = None,
        mode: str = 'loss',
    ):
        """The unified entry for a forward process in both training and test.

        - "loss": For training. Forward and return a dict of losses according
          to the given inputs and data samples. Note that this method handles
          neither back propagation nor optimizer updating, which are done in
          the :meth:`train_step`.
        - "predict": For testing. Forward and return a list of data_sample
          that contains pred_answer for each question.

        Args:
            images (Tensor): A batch of images. The shape of it should be
                (B, C, H, W) for images and (B, T, C, H, W) for videos.
            data_samples (List[DataSample], optional): The annotation data of
                every samples. Required when ``mode="loss"``. Defaults to
                None.
            mode (str): Return what kind of value. Defaults to 'loss'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="loss"``, return a dict of tensor.
            - If ``mode="predict"``, return a list of `DataSample`

        Raises:
            RuntimeError: If ``mode`` is neither ``'loss'`` nor
                ``'predict'``.
        """
        if mode == 'loss':
            return self.loss(images, data_samples)
        elif mode == 'predict':
            return self.predict(images, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, images: torch.Tensor) -> torch.Tensor:
        """Extract features from the input tensor with shape (N, C, ..).

        Args:
            images (Tensor): A batch of images. The shape of it should be
                (B, C, H, W) for images and (B, T, C, H, W) for videos.

        Returns:
            visual_embeds (Tensor): The output features.

        Raises:
            ValueError: If ``images`` is neither 4- nor 5-dimensional.
        """
        # extract visual feature
        if images.ndim == 4:
            visual_embeds = self.vision_backbone(images)[0]
        elif images.ndim == 5:
            # [batch, T, C, H, W] -> [batch * T, C, H, W]
            bs = images.size(0)
            images = images.reshape(-1, *images.shape[2:])
            visual_embeds = self.vision_backbone(images)[0]
            # [batch * num_segs, L, dim] -> [batch, num_segs * L, dim]
            visual_embeds = visual_embeds.reshape(bs, -1,
                                                  *visual_embeds.shape[2:])
        else:
            raise ValueError(
                f'Images with {images.ndim} dims is not supported.')
        return visual_embeds

    def _fuse_questions(self, images: torch.Tensor,
                        data_samples: Optional[List[DataSample]]):
        """Tokenize the questions and fuse them with the visual features.

        Shared by :meth:`loss` and :meth:`predict`, which both need the
        same question encoding.

        Args:
            images (Tensor): A batch of images, forwarded to
                :meth:`extract_feat`.
            data_samples (List[DataSample]): Samples whose ``question``
                field is read via ``sample.get('question')``.

        Returns:
            tuple: ``(questions, multimodal_embeds)`` where ``questions`` is
            the tokenizer output moved to ``self.device`` and
            ``multimodal_embeds`` is the output of
            ``self.multimodal_backbone`` called with ``return_dict=True``.

        Raises:
            ValueError: If ``data_samples`` is None.
        """
        if data_samples is None:
            raise ValueError(
                '`data_samples` is required because the questions are read '
                'from it.')

        visual_embeds = self.extract_feat(images)
        # All-ones mask over every leading dim of the visual features,
        # created directly on the model device.
        image_atts = torch.ones(
            visual_embeds.size()[:-1], dtype=torch.long, device=self.device)

        questions = [sample.get('question') for sample in data_samples]
        questions = self.tokenizer(
            questions, padding='longest',
            return_tensors='pt').to(self.device)
        # Replace the first token of every question with the tokenizer's
        # first additional special token.
        questions.input_ids[:, 0] = \
            self.tokenizer.additional_special_tokens_ids[0]

        # multimodal fusion
        multimodal_embeds = self.multimodal_backbone(
            questions.input_ids,
            attention_mask=questions.attention_mask,
            encoder_hidden_states=visual_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return questions, multimodal_embeds

    def loss(
        self,
        images: torch.Tensor,
        data_samples: Optional[List[DataSample]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        """Generate train_loss from the input tensor and data_samples.

        The given ``data_samples`` are only read; they are not modified.

        Args:
            images (Tensor): A batch of images. The shape of it should be
                (B, C, H, W) for images and (B, T, C, H, W) for videos.
            data_samples (List[DataSample], optional): The annotation data of
                every samples. Each sample must provide ``question`` and
                ``gt_answer``; ``gt_answer_weight`` is optional.

        Returns:
            The value returned by ``self.vqa_head.loss`` (described in
            :meth:`forward` as a dict of tensor).

        Raises:
            ValueError: If ``data_samples`` is None.
        """
        questions, multimodal_embeds = self._fuse_questions(
            images, data_samples)

        # put answer from data_samples into tensor form
        answer_raw_text = []
        for sample in data_samples:
            answer_raw_text.extend(sample.gt_answer)
        answer = self.tokenizer(
            answer_raw_text, padding='longest',
            return_tensors='pt').to(self.device)
        # Padding positions are filled with -100 in the targets.
        answer_targets = answer.input_ids.masked_fill(
            answer.input_ids == self.tokenizer.pad_token_id, -100)

        # Build the weights locally instead of overwriting
        # ``sample.gt_answer_weight`` on the caller's data samples.
        answer_weights = []
        for sample in data_samples:
            if hasattr(sample, 'gt_answer_weight'):
                answer_weights.append(torch.tensor(sample.gt_answer_weight))
            else:
                # follow BLIP setting, set answer_weight to 0.2 for VG
                # dataset.
                answer_weights.append(torch.tensor([0.2]))
        answer_weight = torch.cat(answer_weights, dim=0).to(self.device)

        # Repeat each question's states/mask once per ground-truth answer so
        # they line up with the flattened answer batch. Plain Python ints are
        # used for the counts to avoid iterating over a device tensor.
        question_states, question_atts = [], []
        for b, sample in enumerate(data_samples):
            n = len(sample.gt_answer)
            question_states += [multimodal_embeds.last_hidden_state[b]] * n
            question_atts += [questions.attention_mask[b]] * n

        question_states = torch.stack(question_states, dim=0).to(self.device)
        question_atts = torch.stack(question_atts, dim=0).to(self.device)

        head_feats = dict(
            answer_input_ids=answer.input_ids,
            answer_attention_mask=answer.attention_mask,
            answer_weight=answer_weight,
            answer_targets=answer_targets,
            question_states=question_states,
            question_atts=question_atts,
            batch_size=len(data_samples),
        )

        losses = self.vqa_head.loss(head_feats)

        return losses

    def predict(
        self,
        images: torch.Tensor,
        data_samples: Optional[List[DataSample]] = None,
    ) -> List[DataSample]:
        """Update data_samples that contain pred_answer for each question.

        Args:
            images (Tensor): A batch of images. The shape of it should be
                (B, C, H, W) for images and (B, T, C, H, W) for videos.
            data_samples (List[DataSample], optional): The annotation data of
                every samples. Each sample must provide ``question``.

        Returns:
            List[DataSample]: The input ``data_samples`` with ``pred_answer``
            set on every sample.

        Raises:
            ValueError: If ``data_samples`` is None, or if
                ``self.vqa_head.inference_method`` is neither ``'rank'`` nor
                ``'generate'``.
        """
        inference_method = self.vqa_head.inference_method
        # Fail early with a clear message; otherwise ``answer_candidates``
        # below would be left undefined for an unknown method.
        if inference_method not in ('rank', 'generate'):
            raise ValueError(
                f'Unsupported inference method "{inference_method}", '
                'expected "rank" or "generate".')

        questions, multimodal_embeds = self._fuse_questions(
            images, data_samples)

        if inference_method == 'rank':
            answer_candidates = self.tokenizer(
                self.vqa_head.answer_list,
                padding='longest',
                return_tensors='pt').to(self.device)
            # Every candidate answer starts with the BOS token.
            answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id
        else:
            answer_candidates = None

        head_feats = dict(
            multimodal_embeds=multimodal_embeds.last_hidden_state,
            question_atts=questions.attention_mask,
            answer_candidates=answer_candidates,
            bos_token_id=self.tokenizer.bos_token_id,
            sep_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        outputs = self.vqa_head.predict(head_feats)
        if inference_method == 'rank':
            # The head output is stored as the answer without decoding.
            for answer, data_sample in zip(outputs, data_samples):
                data_sample.pred_answer = answer
        else:
            # In 'generate' mode the head output is decoded by the tokenizer.
            for output, data_sample in zip(outputs, data_samples):
                data_sample.pred_answer = self.tokenizer.decode(
                    output, skip_special_tokens=True)

        return data_samples
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.