Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mask creations of GPTNeoX and GPT2 #31944

Merged
merged 4 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,18 +1030,18 @@ def forward(

# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
Expand Down
36 changes: 17 additions & 19 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,25 +824,23 @@ def forward(
inputs_embeds = self.embed_in(input_ids)

# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
Loading