BlipModel: get_multimodal_features method (#30438)
* add_blip_get_multimodal_features

* Fix docstring error

* reimplement get_multimodal_features

* fix error

* recheck code quality

* add new necessary tests
XavierSpycy authored Apr 30, 2024
1 parent 9112520 commit 0cdb6b3
Showing 2 changed files with 110 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/transformers/models/blip/modeling_blip.py
@@ -814,6 +814,59 @@ def get_image_features(

        return image_features

    @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
    def get_multimodal_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
            obtained by feeding the image embeddings to the text encoder as cross-attention context.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["a photo of a cat", "a photo of a dog"]
        >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")
        >>> multimodal_features = model.get_multimodal_features(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # Patch embeddings from the vision encoder; the all-ones mask lets the text
        # encoder attend to every image token.
        image_embeds = vision_outputs[0]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

        # The text encoder fuses the two modalities by cross-attending over the image embeddings.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]  # pooled (first-token) output of the cross-attended text encoder
        multimodal_features = self.text_projection(pooled_output)

        return multimodal_features

    @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
    def forward(
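For reference, here is the docstring example above assembled into one runnable script, with a shape check added at the end. Everything except the `torch.no_grad()` guard and the final `print` comes straight from the docstring; the expected shape is stated in terms of `projection_dim` rather than a hard-coded value.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BlipModel

model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of a cat", "a photo of a dog"]
inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")

with torch.no_grad():  # inference only, so skip gradient bookkeeping
    multimodal_features = model.get_multimodal_features(**inputs)

# One fused image-text embedding per input text.
print(multimodal_features.shape)  # expected: (len(texts), model.config.projection_dim)
```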
57 changes: 57 additions & 0 deletions tests/models/blip/test_modeling_blip.py
@@ -582,6 +582,63 @@ def test_model_from_pretrained(self):
        model = BlipModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    def test_get_image_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        keys_to_pop = ["input_ids", "attention_mask", "return_loss"]

        for key in keys_to_pop:
            inputs_dict.pop(key)

        model = BlipModel(config).to(torch_device)
        model.eval()
        image_features = model.get_image_features(**inputs_dict)
        self.assertEqual(
            image_features.shape,
            (
                self.model_tester.batch_size,
                model.projection_dim,
            ),
        )

    def test_get_text_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        keys_to_pop = ["pixel_values", "return_loss"]

        for key in keys_to_pop:
            inputs_dict.pop(key)

        model = BlipModel(config).to(torch_device)
        model.eval()
        text_features = model.get_text_features(**inputs_dict)
        self.assertEqual(
            text_features.shape,
            (
                self.model_tester.batch_size,
                model.projection_dim,
            ),
        )

    def test_get_multimodal_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        keys_to_pop = ["return_loss"]

        for key in keys_to_pop:
            inputs_dict.pop(key)

        model = BlipModel(config).to(torch_device)
        model.eval()
        multimodal_features = model.get_multimodal_features(**inputs_dict)
        self.assertEqual(
            multimodal_features.shape,
            (
                self.model_tester.batch_size,
                model.projection_dim,
            ),
        )

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)

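The three tests above all assert the same contract: each `get_*_features` method returns one vector of size `projection_dim` per example. Below is a minimal standalone sketch of that contract for the new method, assuming tiny hypothetical config values chosen only so it runs quickly on CPU; the real tests use the suite's `model_tester` config and `torch_device` instead.

```python
import torch
from transformers import BlipConfig, BlipModel, BlipTextConfig, BlipVisionConfig

# Tiny illustrative config; the sizes here are assumptions, not taken from the tests.
text_config = BlipTextConfig(
    vocab_size=1000,
    hidden_size=32,
    encoder_hidden_size=32,  # must match the vision hidden size for cross-attention
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
)
vision_config = BlipVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    image_size=30,
    patch_size=15,
)
config = BlipConfig.from_text_vision_configs(text_config, vision_config, projection_dim=16)

model = BlipModel(config).eval()

# Random dummy inputs with the shapes the model expects.
batch_size, seq_len = 2, 8
inputs = {
    "input_ids": torch.randint(0, text_config.vocab_size, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
    "pixel_values": torch.rand(batch_size, 3, 30, 30),
}
with torch.no_grad():
    multimodal_features = model.get_multimodal_features(**inputs)

# The shape contract the new test asserts: one projection_dim-sized vector per example.
assert multimodal_features.shape == (batch_size, config.projection_dim)
```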
