Skip to content

Commit

Permalink
Fix mask creations of GPTNeoX and GPT2 (huggingface#31944)
Browse files Browse the repository at this point in the history
* fix the gpt2 and gpt_neox mask creation bug that I introduced

* forgot the reshape of masks when shape > 2

* add tests for gpt neox and gpt2

* nit on a comment
  • Loading branch information
vasqu authored and dataKim1201 committed Oct 7, 2024
1 parent 5f688e3 commit cc205f9
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 31 deletions.
24 changes: 12 additions & 12 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,18 +1030,18 @@ def forward(

# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
if attention_mask is not None:
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
Expand Down
36 changes: 17 additions & 19 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,25 +824,23 @@ def forward(
inputs_embeds = self.embed_in(input_ids)

# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down
34 changes: 34 additions & 0 deletions tests/models/gpt2/test_modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,36 @@ def create_and_check_gpt2_weight_initialization(self, config, *args):
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)

def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
    """Check that a cached forward pass gives the same result with a full mask and with no mask.

    The first ``input_ids.shape[-1] // 2`` tokens are run through the model to obtain
    ``past_key_values``. The remaining tokens are then run twice on top of that cache:
    once with an all-ones attention mask covering the whole sequence and once without
    any attention mask. The two ``last_hidden_state`` tensors must match within ``atol=1e-5``.

    Relevant issue: https://github.com/huggingface/transformers/issues/31943
    """
    model = GPT2Model(config)
    model.to(torch_device)
    model.eval()

    # Only the SDPA path is of interest here; eager works with a `None` attention mask
    uses_sdpa = model.config._attn_implementation == "sdpa"
    assert uses_sdpa, "This test assumes the model to have the SDPA implementation for its attention calculations."

    # Split the sequence into a prefix (fills the cache) and a suffix (run on top of the cache)
    split = input_ids.shape[-1] // 2
    prefix_ids = input_ids[:, :split]
    suffix_ids = input_ids[:, split:]
    full_mask = torch.ones(size=input_ids.size()).to(torch_device)

    # The same cache object is handed to both suffix passes below
    past = model(input_ids=prefix_ids, attention_mask=full_mask[:, :split]).past_key_values

    # Suffix pass with the explicit all-ones mask spanning prefix + suffix
    hidden_with_mask = model(
        input_ids=suffix_ids, attention_mask=full_mask, past_key_values=past
    ).last_hidden_state
    # Suffix pass with no mask at all (which should assume full attention)
    hidden_without_mask = model(suffix_ids, past_key_values=past).last_hidden_state

    outputs_match = torch.allclose(hidden_with_mask, hidden_without_mask, atol=1e-5)
    self.parent.assertTrue(outputs_match)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

Expand Down Expand Up @@ -570,6 +600,10 @@ def test_gpt2_weight_initialization(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs)

def test_cached_forward_with_and_without_attention_mask(self):
    """Run the model tester's cached-forward check (with vs. without an attention mask)."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_cached_forward_with_and_without_attention_mask(*prepared)

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
Expand Down
34 changes: 34 additions & 0 deletions tests/models/gpt_neox/test_modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,36 @@ def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, in
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
    """Check that a cached forward pass gives the same result with a full mask and with no mask.

    The first ``input_ids.shape[-1] // 2`` tokens are run through the model to obtain
    ``past_key_values``. The remaining tokens are then run twice on top of that cache:
    once with an all-ones attention mask covering the whole sequence and once without
    any attention mask. The two ``last_hidden_state`` tensors must match within ``atol=1e-5``.

    Relevant issue: https://github.com/huggingface/transformers/issues/31943
    """
    model = GPTNeoXModel(config)
    model.to(torch_device)
    model.eval()

    # Only the SDPA path is of interest here; eager works with a `None` attention mask
    uses_sdpa = model.config._attn_implementation == "sdpa"
    assert uses_sdpa, "This test assumes the model to have the SDPA implementation for its attention calculations."

    # Split the sequence into a prefix (fills the cache) and a suffix (run on top of the cache)
    split = input_ids.shape[-1] // 2
    prefix_ids = input_ids[:, :split]
    suffix_ids = input_ids[:, split:]
    full_mask = torch.ones(size=input_ids.size()).to(torch_device)

    # The same cache object is handed to both suffix passes below
    past = model(input_ids=prefix_ids, attention_mask=full_mask[:, :split]).past_key_values

    # Suffix pass with the explicit all-ones mask spanning prefix + suffix
    hidden_with_mask = model(
        input_ids=suffix_ids, attention_mask=full_mask, past_key_values=past
    ).last_hidden_state
    # Suffix pass with no mask at all (which should assume full attention)
    hidden_without_mask = model(suffix_ids, past_key_values=past).last_hidden_state

    outputs_match = torch.allclose(hidden_with_mask, hidden_without_mask, atol=1e-5)
    self.parent.assertTrue(outputs_match)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask, token_labels = config_and_inputs
Expand Down Expand Up @@ -300,6 +330,10 @@ def test_model_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

def test_cached_forward_with_and_without_attention_mask(self):
    """Run the model tester's cached-forward check (with vs. without an attention mask)."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_cached_forward_with_and_without_attention_mask(*prepared)

# Always skipped via the decorator below; the body is deliberately a no-op.
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
Expand Down

0 comments on commit cc205f9

Please sign in to comment.