From 708fb9e5bbfc42dffde2665007a715aa5ade9f8c Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 13 Jul 2024 01:52:31 +0200 Subject: [PATCH 1/4] fix mask creation of gpt2 and gpt_neox caused by me --- src/transformers/models/gpt2/modeling_gpt2.py | 23 ++++++------ .../models/gpt_neox/modeling_gpt_neox.py | 35 +++++++++---------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 2b9d300aa9e1bc..f4baeefa97c88d 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1030,18 +1030,17 @@ def forward( # Attention mask. _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None - if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, -1) - if self._attn_implementation == "flash_attention_2": - attention_mask = attention_mask if 0 in attention_mask else None - elif _use_sdpa: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask=attention_mask, - input_shape=(batch_size, input_shape[-1]), - inputs_embeds=inputs_embeds, - past_key_values_length=past_length, - ) - else: + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif _use_sdpa: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, input_shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) + else: + if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index fd5e0b4fe62e25..9b51a8f77969ec 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -824,25 +824,22 @@ def forward( inputs_embeds = self.embed_in(input_ids) # Attention mask. 
- if attention_mask is not None: - assert batch_size > 0, "batch_size has to be defined and > 0" - attention_mask = attention_mask.view(batch_size, -1) - if self._attn_implementation == "flash_attention_2": - attention_mask = attention_mask if 0 in attention_mask else None - elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask=attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=inputs_embeds, - past_key_values_length=past_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask=attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=inputs_embeds, - past_key_values_length=past_length, - ) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) + else: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head From 7f1fb6c7a6dd194a2b56366af4cabdb5cfed39c2 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Sat, 13 Jul 2024 02:17:34 +0200 Subject: [PATCH 2/4] forgot the reshape of masks when shape > 2 --- src/transformers/models/gpt2/modeling_gpt2.py | 1 + src/transformers/models/gpt_neox/modeling_gpt_neox.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index f4baeefa97c88d..7a51cb3eb2cdb8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1030,6 +1030,7 @@ def forward( # Attention mask. _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None + attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None if self._attn_implementation == "flash_attention_2": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif _use_sdpa: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9b51a8f77969ec..32988e88df34a8 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -824,6 +824,7 @@ def forward( inputs_embeds = self.embed_in(input_ids) # Attention mask. 
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None if self._attn_implementation == "flash_attention_2": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None: From a01e26fd996970ef7d447d07db96c68bd5c132a1 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 16 Jul 2024 18:53:47 +0200 Subject: [PATCH 3/4] add tests for gpt neox and gpt2 --- tests/models/gpt2/test_modeling_gpt2.py | 34 +++++++++++++++++++ .../models/gpt_neox/test_modeling_gpt_neox.py | 34 +++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 5755658288f568..82a420a5442d4e 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -426,6 +426,36 @@ def create_and_check_gpt2_weight_initialization(self, config, *args): self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001) self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01) + def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args): + # Relevant issue: https://github.com/huggingface/transformers/issues/31943 + model = GPT2Model(config) + model.to(torch_device) + model.eval() + + # We want this for SDPA, eager works without a `None` attention mask + assert ( + model.config._attn_implementation == "sdpa" + ), "This test assumes the model to have the SDPA implementation for its attention calculations." + + # Prepare cache and non_cache input, needs a full attention mask + cached_len = input_ids.shape[-1] // 2 + input_mask = torch.ones(size=input_ids.size()).to(torch_device) + cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]} + non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask} + + # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) + cache_outputs = model(**cache_inputs) + full_outputs_with_attention_mask = model( + **non_cache_inputs, past_key_values=cache_outputs.past_key_values + ).last_hidden_state + full_outputs_without_attention_mask = model( + non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values + ).last_hidden_state + + self.parent.assertTrue( + torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5) + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -570,6 +600,10 @@ def test_gpt2_weight_initialization(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs) + def test_cached_forward_with_and_without_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs) + @unittest.skip( reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 51a4d235c3bc5f..b28c32a30d4656 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ 
b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -219,6 +219,36 @@ def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, in # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args): + # Relevant issue: https://github.com/huggingface/transformers/issues/31943 + model = GPTNeoXModel(config) + model.to(torch_device) + model.eval() + + # We want this for SDPA, eager works without a `None` attention mask + assert ( + model.config._attn_implementation == "sdpa" + ), "This test assumes the model to have the SDPA implementation for its attention calculations." + + # Prepare cache and non_cache input, needs a full attention mask + cached_len = input_ids.shape[-1] // 2 + input_mask = torch.ones(size=input_ids.size()).to(torch_device) + cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]} + non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask} + + # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) + cache_outputs = model(**cache_inputs) + full_outputs_with_attention_mask = model( + **non_cache_inputs, past_key_values=cache_outputs.past_key_values + ).last_hidden_state + full_outputs_without_attention_mask = model( + non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values + ).last_hidden_state + + self.parent.assertTrue( + torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5) + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, input_ids, input_mask, token_labels = config_and_inputs @@ -300,6 +330,10 @@ def test_model_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + def test_cached_forward_with_and_without_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs) + @unittest.skip(reason="Feed forward chunking is not implemented") def test_feed_forward_chunking(self): pass From 431295585c61c0a84ce5fd89a47f53eea5dad3aa Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 16 Jul 2024 18:54:50 +0200 Subject: [PATCH 4/4] nit on a comment --- tests/models/gpt2/test_modeling_gpt2.py | 2 +- tests/models/gpt_neox/test_modeling_gpt_neox.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 82a420a5442d4e..3f96c20ab2dbd9 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -432,7 +432,7 @@ def create_and_check_cached_forward_with_and_without_attention_mask(self, config model.to(torch_device) model.eval() - # We want this for SDPA, eager works without a `None` attention mask + # We want this for SDPA, eager works with a `None` attention mask assert ( model.config._attn_implementation == "sdpa" ), "This test assumes the model to have the SDPA implementation for its attention calculations." 
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index b28c32a30d4656..af162f50713e96 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -225,7 +225,7 @@ def create_and_check_cached_forward_with_and_without_attention_mask(self, config model.to(torch_device) model.eval() - # We want this for SDPA, eager works without a `None` attention mask + # We want this for SDPA, eager works with a `None` attention mask assert ( model.config._attn_implementation == "sdpa" ), "This test assumes the model to have the SDPA implementation for its attention calculations."
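
Below is a minimal, self-contained sketch of the behavior the new tests lock in, for checking the fix outside the test suite. It assumes a transformers checkout with these patches applied and torch with SDPA available; the `gpt2` checkpoint name, sequence length, and tolerance are illustrative choices, not taken from the patches. Before the fix, the SDPA path only built the 4D causal mask when an `attention_mask` was passed, so the two cached forwards below could diverge; with the fix they should match (issue #31943).

    import torch
    from transformers import GPT2Model

    torch.manual_seed(0)

    # SDPA is the implementation affected by the mask-creation bug.
    model = GPT2Model.from_pretrained("gpt2", attn_implementation="sdpa").eval()

    input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
    full_mask = torch.ones_like(input_ids)
    cached_len = input_ids.shape[-1] // 2

    with torch.no_grad():
        # Prefill the cache on the first half of the sequence.
        cache_outputs = model(
            input_ids=input_ids[:, :cached_len],
            attention_mask=full_mask[:, :cached_len],
        )

        # Continue decoding once with the full attention mask and once with
        # attention_mask=None, which should be treated as "attend to everything".
        with_mask = model(
            input_ids=input_ids[:, cached_len:],
            attention_mask=full_mask,
            past_key_values=cache_outputs.past_key_values,
        ).last_hidden_state
        without_mask = model(
            input_ids=input_ids[:, cached_len:],
            past_key_values=cache_outputs.past_key_values,
        ).last_hidden_state

    # With the fix applied, both cached forwards agree.
    print(torch.allclose(with_mask, without_mask, atol=1e-5))

The same check holds for GPTNeoXModel with an SDPA-capable checkpoint, which is what the second new test covers.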