
[Mistral] Add Flash Attention-2 support for mistral #26464

Merged · 28 commits · Oct 3, 2023
Commits (28, all by younesbelkada); the diff below shows changes from 4 commits.
- 5d9bc48 add FA-2 support for mistral (Sep 28, 2023)
- 0983d88 fixup (Sep 28, 2023)
- b8c9198 add sliding windows (Sep 28, 2023)
- d740849 Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t… (Sep 28, 2023)
- bd58ca7 fixing few nits (Sep 28, 2023)
- 43b0289 v1 slicing cache - logits do not match (Sep 29, 2023)
- ed2616f add comment (Sep 29, 2023)
- 7cafc2d fix bugs (Oct 2, 2023)
- 2b8c7b4 more mem efficient (Oct 2, 2023)
- 4a3387d add warning once (Oct 2, 2023)
- 885b601 add warning once (Oct 2, 2023)
- 172d99a oops (Oct 2, 2023)
- 253b383 fixup (Oct 2, 2023)
- e4d0fb7 more comments (Oct 2, 2023)
- a245722 copy (Oct 2, 2023)
- 3079896 Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t… (Oct 2, 2023)
- e71c50d add safety checker (Oct 2, 2023)
- a21d903 Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t… (Oct 2, 2023)
- 5d1f589 fixup (Oct 2, 2023)
- b478e04 Update src/transformers/models/mistral/modeling_mistral.py (Oct 3, 2023)
- 2fe2f49 copied from (Oct 3, 2023)
- 25789d1 up (Oct 3, 2023)
- 05ec7f4 raise when padding side is right (Oct 3, 2023)
- 5a79195 Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t… (Oct 3, 2023)
- f9a69bc fixup (Oct 3, 2023)
- 6a48dd3 add doc + few minor changes (Oct 3, 2023)
- 76763c7 Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t… (Oct 3, 2023)
- c286946 fixup (Oct 3, 2023)
docs/source/en/perf_infer_gpu_one.md (1 change: 1 addition & 0 deletions)

@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:

- Llama
- Mistral
(ArthurZucker marked this conversation as resolved.)
- Falcon

You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens, *which is currently not supported by the `BetterTransformer` API below.*
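As a quick illustration of what this PR enables, below is a minimal sketch of loading Mistral with Flash Attention 2. It assumes a CUDA GPU, a flash-attn 2.x install (plus accelerate for `device_map="auto"`), and the `use_flash_attention_2` loading flag available in transformers around the time of this PR; the checkpoint name is only an example.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,    # FA-2 kernels expect fp16 or bf16 activations
    use_flash_attention_2=True,   # routes attention through MistralFlashAttention2
    device_map="auto",            # requires accelerate
)

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Under the hood the flag sets `config._flash_attn_2_enabled`, which is what the `MistralDecoderLayer` change further down dispatches on.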
src/transformers/models/mistral/modeling_mistral.py (241 changes: 238 additions & 3 deletions)

@@ -19,25 +19,52 @@
# limitations under the License.
""" PyTorch Mistral model."""
import math
import inspect
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_mistral import MistralConfig


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa

_is_flash_using_slicing_windows = "window_size" in list(inspect.signature(flash_attn_func).parameters)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MistralConfig"


def _get_unpad_data(padding_mask):
(younesbelkada marked this conversation as resolved.)
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
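To make the unpadding bookkeeping concrete, here is a small self-contained sketch of what `_get_unpad_data` computes for a toy left-padded batch; the values are illustrative and follow directly from the function above.

```python
# Illustrative only: _get_unpad_data on a toy padding mask (1 = real token, 0 = padding).
import torch
import torch.nn.functional as F

padding_mask = torch.tensor([[0, 0, 1, 1, 1],
                             [0, 1, 1, 1, 1]], dtype=torch.int32)

seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 4])
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()   # tensor([2, 3, 4, 6, 7, 8, 9])
max_seqlen_in_batch = seqlens_in_batch.max().item()                         # 4
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 7])

# flash_attn_varlen_func consumes `cu_seqlens` as cumulative offsets into the packed
# (unpadded) batch and `max_seqlen_in_batch` as the longest real sequence length.
```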


def _make_sliding_window_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
@@ -226,6 +253,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -291,11 +319,208 @@ def forward(
return attn_output, attn_weights, past_key_value


class MistralFlashAttention2(MistralAttention):
"""
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
):
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

use_sliding_windows = _is_flash_using_slicing_windows and kv_seq_len > self.config.sliding_window and not self.training
Contributor: typo, should be "_is_flash_using_sliding_windows"

Contributor: also I don't think this bool is needed. If self.config.sliding_window is not None -> use sliding_window always, whether training or inferencing, no?

Contributor Author: Hmm, good point. For some reason I thought that feature only works for inference (from the source code's readme, https://github.com/mistralai/mistral-src#sliding-window-to-speed-up-inference-and-reduce-memory-pressure, I read "speed up inference", so I thought it was only available for inference) - will remove that condition.
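For illustration, here is a standalone sketch of the simplified condition suggested above (typo fixed, `not self.training` dropped); the helper name is hypothetical and this is not necessarily the exact code that was merged.

```python
import inspect
from typing import Callable, Optional


def should_use_sliding_windows(flash_attn_func: Callable, sliding_window: Optional[int], kv_seq_len: int) -> bool:
    """Use sliding-window attention whenever the installed flash-attn exposes `window_size`
    and the key/value length exceeds the configured window, in training and inference alike."""
    flash_supports_window = "window_size" in inspect.signature(flash_attn_func).parameters
    return flash_supports_window and sliding_window is not None and kv_seq_len > sliding_window
```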


if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout

# In PEFT, usually we cast the layer norms in float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to float16 just to be sure everything works as expected.
# This might slow down training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
Collaborator: not sure if this should be included, since it's peft-only; this always casts to float16 even if the input is bfloat16.

Contributor Author: MistralRMSNorm behaves exactly like LlamaRMSNorm, so it will silently cast the hidden states to fp32, therefore this is needed. As mentioned offline, I will address a proper fix for the bf16 issues.
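As a sketch of the bf16 concern raised here, one dtype-preserving alternative to the hard-coded float16 cast is to reuse the dtype of the attention projection weights; the helper below is hypothetical and is not part of this revision of the PR.

```python
import torch
from torch import nn


def cast_back_for_flash_attention(states: torch.Tensor, reference_proj: nn.Linear) -> torch.Tensor:
    """If upcasted layer norms silently promoted the states to float32, cast them back
    to whatever dtype the projection weights were loaded in (fp16 or bf16)."""
    if states.dtype == torch.float32:
        return states.to(reference_proj.weight.dtype)
    return states
```

A call site would then read `query_states = cast_back_for_flash_attention(query_states, self.q_proj)`, keeping bf16 checkpoints in bf16 instead of forcing them to fp16.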


# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate, use_sliding_windows=use_sliding_windows)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None, use_sliding_windows=False
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores, and finally pad the attention scores back.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout probability.
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
# Contains at least one padding token in the sequence
if padding_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
window_size=(self.config.sliding_window, self.config.sliding_window)
)
Contributor: could probably be factored? Something like window_size=(self.config.sliding_window or -1, -1)?
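As a hedged sketch of that factoring: flash-attn treats `window_size=(-1, -1)` as "no sliding window", so the two branches could in principle share a single call. The helper below is only an illustration of the suggestion, not the code in this revision.

```python
from typing import Optional, Tuple


def flash_window_size(sliding_window: Optional[int], use_sliding_windows: bool) -> Tuple[int, int]:
    """Return the `window_size` argument for flash-attn; (-1, -1) disables the window."""
    if use_sliding_windows and sliding_window is not None:
        return (sliding_window, sliding_window)
    return (-1, -1)
```

Both `flash_attn_varlen_func` calls above could then pass `window_size=flash_window_size(self.config.sliding_window, use_sliding_windows)` and the if/else collapses.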


attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window // 2, self.config.sliding_window // 2)
Contributor: same comment on factoring. I have no idea what's going on with // 2 here, but I don't know what padding_mask is :|

Contributor Author: Oh yeah, ignore the //2, I used it for testing purposes :D

Contributor Author: padding_mask is the "pure" attention mask (not the causal mask): 0 if padding token, 1 if not --> I use it in the control flow of the flash attention modules to decide whether I need to pad / unpad or not.

)

return attn_output
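To make the `padding_mask` control flow described in the review discussion concrete, here is a small runnable illustration of when the padded vs. unpadded flash-attention path is taken; the tensors are toy values.

```python
# Illustrative: the "pure" 2D padding mask (1 = real token, 0 = padding, left-padded),
# as opposed to the expanded 4D causal mask built by _prepare_decoder_attention_mask.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# In MistralModel.forward, padding_mask is only set when the batch actually contains padding.
padding_mask = attention_mask if (0 in attention_mask) else None

# padding_mask is not None -> _flash_attention_forward unpads with _upad_input,
#                             calls flash_attn_varlen_func, then re-pads with pad_input.
# padding_mask is None     -> it calls flash_attn_func directly on the dense batch.
print(padding_mask is not None)  # True for this toy batch
```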

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
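As a plain-PyTorch illustration of what the unpadding in `_upad_input` amounts to (the real implementation delegates to flash-attn's `index_first_axis` / `unpad_input`), this packs only the real tokens of a left-padded batch into one flat sequence:

```python
import torch

batch, seqlen, heads, dim = 2, 4, 1, 2
hidden = torch.arange(batch * seqlen * heads * dim, dtype=torch.float32).view(batch, seqlen, heads, dim)
padding_mask = torch.tensor([[1, 1, 1, 1],
                             [0, 0, 1, 1]])

# Gather the rows of real tokens, dropping padding positions entirely.
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
packed = hidden.view(batch * seqlen, heads, dim)[indices]

print(packed.shape)  # torch.Size([6, 1, 2]) -> 4 + 2 real tokens, no padding rows
```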


class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MistralAttention(config=config)
self.self_attn = (
MistralAttention(config=config)
if not getattr(config, "_flash_attn_2_enabled", False)
else MistralFlashAttention2(config)
)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -308,6 +533,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -335,6 +561,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states

@@ -382,6 +609,7 @@ class MistralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -569,11 +797,17 @@ def forward(

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

padding_mask = None

# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
elif 0 in attention_mask:
padding_mask = attention_mask

attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
@@ -607,7 +841,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)

return custom_forward

@@ -625,6 +859,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)

hidden_states = layer_outputs[0]