diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index 39506478f17926..bd61a1cbd781e7 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -814,6 +814,59 @@ def get_image_features(
 
         return image_features
 
+    @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
+    def get_multimodal_features(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
+            obtained by passing the image embeddings to the text encoder via its cross-attention layers.
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, BlipModel
+
+        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
+        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> texts = ["a photo of a cat", "a photo of a dog"]
+        >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")
+
+        >>> multimodal_features = model.get_multimodal_features(**inputs)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=True,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        image_embeds = vision_outputs[0]
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=return_dict,
+        )
+
+        pooled_output = text_outputs[1]  # pooled_output
+        multimodal_features = self.text_projection(pooled_output)
+
+        return multimodal_features
+
     @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
     def forward(
diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py
index 86ea1a8e363607..4caba63a310462 100644
--- a/tests/models/blip/test_modeling_blip.py
+++ b/tests/models/blip/test_modeling_blip.py
@@ -582,6 +582,63 @@ def test_model_from_pretrained(self):
             model = BlipModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
+    def test_get_image_features(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        keys_to_pop = ["input_ids", "attention_mask", "return_loss"]
+
+        for key in keys_to_pop:
+            inputs_dict.pop(key)
+
+        model = BlipModel(config).to(torch_device)
+        model.eval()
+        image_features = model.get_image_features(**inputs_dict)
+        self.assertEqual(
+            image_features.shape,
+            (
+                self.model_tester.batch_size,
+                model.projection_dim,
+            ),
+        )
+
+    def test_get_text_features(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        keys_to_pop = ["pixel_values", "return_loss"]
+
+        for key in keys_to_pop:
+            inputs_dict.pop(key)
+
+        model = BlipModel(config).to(torch_device)
+        model.eval()
+        text_features = model.get_text_features(**inputs_dict)
+        self.assertEqual(
+            text_features.shape,
+            (
+                self.model_tester.batch_size,
+                model.projection_dim,
+            ),
+        )
+
+    def test_get_multimodal_features(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        keys_to_pop = ["return_loss"]
+
+        for key in keys_to_pop:
+            inputs_dict.pop(key)
+
+        model = BlipModel(config).to(torch_device)
+        model.eval()
+        multimodal_features = model.get_multimodal_features(**inputs_dict)
+        self.assertEqual(
+            multimodal_features.shape,
+            (
+                self.model_tester.batch_size,
+                model.projection_dim,
+            ),
+        )
+
     def test_pt_tf_model_equivalence(self):
         super().test_pt_tf_model_equivalence(allow_missing_keys=True)
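For context, here is a minimal usage sketch (not part of the patch) showing how the new `get_multimodal_features` sits alongside the existing `get_image_features` and `get_text_features`. The checkpoint and image URL are the ones from the docstring example above, and the shape comment follows the new tests, which expect `(batch_size, projection_dim)` outputs from all three methods.

```python
from PIL import Image
import requests
import torch

from transformers import AutoProcessor, BlipModel

model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="a photo of a cat", return_tensors="pt")

with torch.no_grad():
    # Unimodal projections already exposed by BlipModel.
    image_features = model.get_image_features(pixel_values=inputs["pixel_values"])
    text_features = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
    # New fused embedding: the text encoder cross-attends over the image embeddings.
    multimodal_features = model.get_multimodal_features(**inputs)

# All three are projected to the same dimensionality: (batch_size, projection_dim).
print(image_features.shape, text_features.shape, multimodal_features.shape)
```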