
Paligemma- fix devices and dtype assignments #31008

Merged 2 commits on May 24, 2024
15 changes: 9 additions & 6 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -301,14 +301,15 @@ def _merge_input_ids_with_image_features(
         pad_mask = input_ids == self.pad_token_id

         # expand masks to match embedding dimension
-        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
+        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
         # insert padding and text token embeddings
         final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         # insert image embeddings - the image mask is always less or equal to the sentence in length
         final_embedding = final_embedding.masked_scatter(
-            image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
+            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
+            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         if attention_mask is not None:
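For context, a minimal self-contained sketch of the constraint this hunk addresses (shapes and values are hypothetical): masked_scatter requires the mask and the source to live on the target tensor's device, and the source must also match its dtype, which is why the hunk adds explicit .to(...) moves and casts.

    import torch

    final_embedding = torch.zeros(1, 6, 4, dtype=torch.bfloat16)  # merged embedding buffer
    image_mask = torch.tensor([[True, True, True, False, False, False]])  # image token positions
    scaled_image_features = torch.randn(1, 3, 4)  # float32 on purpose, to show the dtype cast

    # mask and source are moved/cast to match the target, mirroring the PR's change
    final_embedding = final_embedding.masked_scatter(
        image_mask.unsqueeze(-1).expand_as(final_embedding).to(final_embedding.device),
        scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
    )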
@@ -329,10 +330,12 @@ def _merge_input_ids_with_image_features(
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             # unmask the prefill
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                token_type_ids[:, None, None, :] == 0, 0
+                token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
             )
Review thread on diff lines 337 to 339:
Collaborator: this one does not make sense to me

Contributor Author: with the masked_fill, you need both tensors to be on the same device, right?
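(That matches PyTorch's behavior on recent versions; a quick repro with hypothetical tensors, which needs a GPU to trigger the mismatch:)

    import torch

    if torch.cuda.is_available():
        causal_mask = torch.zeros(1, 1, 4, 4, device="cuda:0")
        token_type_ids = torch.tensor([[0, 0, 1, 1]])  # still on CPU, as passed into forward()
        try:
            causal_mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)
        except RuntimeError as err:
            print(err)  # complains that self and mask are on different devices
        # the fix in this hunk: move the mask to causal_mask.device first
        causal_mask.masked_fill(token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0)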

Collaborator: Ah, sorry, I read too fast.
So token_type_ids's device is not correctly inferred? Or are we not creating the causal mask on the correct device? It should be created on the input or attention mask's device for consistency, since accelerate will transfer it accordingly when it's used, I think.

Contributor Author: you're right! will update to move to device at creation time

Contributor Author: hmm, I think we are setting the causal_mask to the correct device. It's the token_type_ids device that is indeed not correctly inferred.

Collaborator: But from the comment of @SunMarc I would suspect both devices to be the same, no?

Contributor Author: hmm, doesn't seem like it; tried on a multi-GPU env with device_map set to "auto" and got:
[screenshot attached]
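(For reference, a hedged sketch of that kind of repro setup; the checkpoint name is illustrative, and this assumes a multi-GPU box with accelerate installed:)

    import torch
    from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

    model = PaliGemmaForConditionalGeneration.from_pretrained(
        "google/paligemma-3b-pt-224", device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
    print(model.hf_device_map)  # shows how submodules were sharded across the GPUs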

Contributor Author: @SunMarc if you have an idea here - token_type_ids is created by the processor along with input_ids and passed to the forward normally
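(A hedged sketch of that flow; the checkpoint and prompt are illustrative, and the exact keys returned can vary by processor version:)

    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
    inputs = processor(text="caption en", images=Image.new("RGB", (224, 224)), return_tensors="pt")
    # input_ids and token_type_ids (when present) leave the processor together, on CPU
    print({k: v.device for k, v in inputs.items()})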

SunMarc (Member), May 24, 2024: So, from the code and the image you shared, I see that token_type_ids is indeed on the same device as input_ids. However, since you created the causal_mask on inputs_embeds.device, token_type_ids and the causal_mask might not be on the same device:

            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )

where
dtype, device = inputs_embeds.dtype, inputs_embeds.device
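(A toy illustration of that point, with a hypothetical two-device placement: the mask is created where the embeddings live, which need not be where input_ids and token_type_ids live.)

    import torch

    input_ids = torch.tensor([[1, 2, 3]])  # stays on CPU, next to token_type_ids
    embed = torch.nn.Embedding(10, 8)
    if torch.cuda.is_available():
        embed = embed.to("cuda:0")  # as if device_map placed the embedding layer on a GPU
        inputs_embeds = embed(input_ids.to("cuda:0"))
    else:
        inputs_embeds = embed(input_ids)

    dtype, device = inputs_embeds.dtype, inputs_embeds.device
    causal_mask = torch.full((3, 3), torch.finfo(dtype).min, dtype=dtype, device=device)
    print(causal_mask.device, input_ids.device)  # can differ, so token_type_ids needs an explicit move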

Contributor Author: alright, thanks! In that case, can we keep it as it is? The other way would be to create the causal mask on input_ids.device; I'm not sure one is better than the other - inputs_embeds is much larger in general.
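(A back-of-the-envelope check of that size argument, with hypothetical but typical shapes:)

    import torch

    batch, seq_len, embed_dim = 1, 1024, 2048
    token_type_ids = torch.zeros(batch, seq_len, dtype=torch.long)
    inputs_embeds = torch.zeros(batch, seq_len, embed_dim, dtype=torch.bfloat16)

    print(token_type_ids.numel() * token_type_ids.element_size())  # 8192 bytes to move
    print(inputs_embeds.numel() * inputs_embeds.element_size())    # 4194304 bytes to move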

             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(

@@ -484,7 +487,7 @@ def forward(
             # we use the input attention mask to shift the logits and labels, because it is 2D.
             shift_attention_mask = input_attention_mask[..., 1:]
             shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
-            shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
+            shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
         else:
             shift_logits = shift_logits.contiguous()
             shift_labels = shift_labels.contiguous()
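For context, a minimal self-contained sketch of the shifted-loss pattern this hunk touches (shapes are hypothetical): logits are aligned to predict the next token, labels are shifted left, and the 2D attention mask filters out padding before the flattened cross-entropy.

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 32)  # (batch, seq_len, vocab)
    labels = torch.randint(0, 32, (2, 5))
    input_attention_mask = torch.ones(2, 5, dtype=torch.long)

    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    shift_attention_mask = input_attention_mask[..., 1:]

    # each tensor is indexed by the mask moved to *its own* device, which is the fix above
    shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
    shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
    loss = F.cross_entropy(shift_logits.float(), shift_labels)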