From 47f47d56bda0a8000c48969c3e073b6d59a985d2 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Tue, 3 Oct 2023 13:44:46 +0200
Subject: [PATCH] [`Mistral`] Add Flash Attention-2 support for `mistral`
(#26464)
* add FA-2 support for mistral
* fixup
* add sliding windows
* fixing few nits
* v1 slicing cache - logits do not match
* add comment
* fix bugs
* more mem efficient
* add warning once
* add warning once
* oops
* fixup
* more comments
* copy
* add safety checker
* fixup
* Update src/transformers/models/mistral/modeling_mistral.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* copied from
* up
* raise when padding side is right
* fixup
* add doc + few minor changes
* fixup
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
docs/source/en/model_doc/mistral.md | 45 +++
docs/source/en/perf_infer_gpu_one.md | 1 +
.../models/mistral/modeling_mistral.py | 318 +++++++++++++++++-
tests/models/mistral/test_modeling_mistral.py | 74 +++-
tests/test_modeling_common.py | 2 +-
5 files changed, 435 insertions(+), 5 deletions(-)
diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md
index 578906abe643ee..4a8257d25d6a9e 100644
--- a/docs/source/en/model_doc/mistral.md
+++ b/docs/source/en/model_doc/mistral.md
@@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = MistralForCausalLM.from_pretrained("/output/path")
```
+## Combining Mistral and Flash Attention 2
+
+First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
+
+To load and run a model using Flash Attention 2, refer to the snippet below:
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+>>> prompt = "My favourite condiment is"
+
+>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
+>>> model.to(device)
+
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
+>>> tokenizer.batch_decode(generated_ids)[0]
+"The expected outupt"
+```
+
+### Expected speedups
+
+Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.
+
+
+
+
+
+### Sliding window Attention
+
+The current implementation supports the sliding window attention mechanism and memory efficient cache management.
+To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
+
+The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding.
+
## The Mistral Team
Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index f0c0bf0b107154..d24299012e9fe1 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:
- Llama
+- Mistral
- Falcon
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 62cfca29465f3e..8edd0fc60e6299 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -18,10 +18,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mistral model."""
+import inspect
import math
from typing import List, Optional, Tuple, Union
import torch
+import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -29,15 +31,41 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_available,
+ logging,
+ replace_return_docstrings,
+)
from .configuration_mistral import MistralConfig
+if is_flash_attn_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(padding_mask):
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
def _make_sliding_window_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
@@ -226,6 +254,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -291,11 +320,271 @@ def forward(
return attn_output, attn_weights, past_key_value
+class MistralFlashAttention2(MistralAttention):
+ """
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and hasattr(self.config, "sliding_window") is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
+ " make sure to upgrade flash-attn library."
+ )
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
+ slicing_tokens = kv_seq_len - self.config.sliding_window
+
+ past_key = past_key_value[0]
+ past_value = past_key_value[1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ past_key_value = (past_key, past_value)
+
+ if padding_mask is not None:
+ padding_mask = padding_mask[:, slicing_tokens:]
+ padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1)
+
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # TODO: Mistral does not have dropout in the config??
+ # It is recommended to use dropout with FA according to the docs
+ # when training.
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in float32, this might be related to"
+ " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ " float16."
+ )
+
+ query_states = query_states.to(torch.float16)
+ key_states = key_states.to(torch.float16)
+ value_states = value_states.to(torch.float16)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ padding_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ padding_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ padding_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ # Contains at least one padding token in the sequence
+ if padding_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, padding_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ if not use_sliding_windows:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=True,
+ )
+ else:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=True,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ if not use_sliding_windows:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+ )
+ else:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=True,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+ # On the first iteration we need to properly re-create the padding mask
+ # by slicing it on the proper place
+ if kv_seq_len != padding_mask.shape[-1]:
+ padding_mask_num_tokens = padding_mask.shape[-1]
+ padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :]
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ padding_mask = padding_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = MistralAttention(config=config)
+ self.self_attn = (
+ MistralAttention(config=config)
+ if not getattr(config, "_flash_attn_2_enabled", False)
+ else MistralFlashAttention2(config)
+ )
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -308,6 +597,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -335,6 +625,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
+ padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
@@ -382,6 +673,7 @@ class MistralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -569,11 +861,30 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
+
+ padding_mask = None
+
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
+ elif 0 in attention_mask:
+ padding_mask = attention_mask
+
+ if (
+ padding_mask is not None
+ and hasattr(self.config, "_flash_attn_2_enabled")
+ and self.config._flash_attn_2_enabled
+ ):
+ is_padding_right = padding_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
@@ -607,7 +918,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
- return module(*inputs, past_key_value, output_attentions)
+ return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return custom_forward
@@ -625,6 +936,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
+ padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index df3bcf9d671fe2..403f2cc7347041 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -15,10 +15,13 @@
""" Testing suite for the PyTorch Mistral model. """
+import tempfile
import unittest
+from pytest import mark
+
from transformers import AutoTokenizer, MistralConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -351,6 +354,75 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_2:
+ return
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
+ ).to(torch_device)
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference_padding_right(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ return
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+ )
+ model.to(torch_device)
+
+ dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
+
+ _ = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ with self.assertRaises(ValueError):
+ _ = model_fa(
+ dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True
+ ).hidden_states[-1]
+
@require_torch
class MistralIntegrationTest(unittest.TestCase):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 2789fe32c143f0..0a17c13a01215b 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2926,7 +2926,7 @@ def test_flash_attn_2_generate_use_cache(self):
model.save_pretrained(tmpdirname)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True