Mistral: Sliding Window Attention with Flash Attention and Sample Packing #732

Merged: 10 commits, Oct 16, 2023
setup.py: 2 changes (1 addition, 1 deletion)
@@ -46,7 +46,7 @@ def parse_requirements():
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn>=2.2.1",
"flash-attn>=2.3.0",
],
"deepspeed": [
"deepspeed",
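
The bump to flash-attn>=2.3.0 matters because flash-attn 2.3.0 introduced sliding window (local) attention via the window_size argument that this patch passes to the kernels. A minimal standalone sketch of its semantics (not part of the patch; assumes flash-attn >= 2.3.0 and a CUDA device):

import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, num_heads, head_dim), fp16/bf16 tensors on GPU
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

# window_size=(-1, -1) means full attention; (w, w) lets each query attend only
# to keys within w positions on either side (the right side is moot with causal=True)
full_out = flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
local_out = flash_attn_func(q, k, v, causal=True, window_size=(16, 16))
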
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py: 109 changes (104 additions, 5 deletions)
@@ -14,6 +14,9 @@
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
)


@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: the HF implementation is wrong as of 2023-10-14; torch.triu needs the +1 on the diagonal here
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)

if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
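
For intuition, a minimal standalone sketch (not part of the patch) of the band the tril/triu pair above produces for tgt_len=5 and sliding_window=3: each query attends to itself and at most the two previous positions, and torch.log maps kept entries to 0.0 and masked entries to -inf.

import torch

tgt_len, sliding_window = 5, 3
band = torch.tril(torch.ones(tgt_len, tgt_len), diagonal=0)   # causal part
band = torch.triu(band, diagonal=-sliding_window + 1)         # drop keys older than the window
print(band)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1.]])
print(torch.log(band))   # 1 -> 0.0 (attend), 0 -> -inf (masked), as returned above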


# Disable the transformation of the attention mask in MistralModel as flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None:
return attention_mask

# NOTE: the attention mask and the sliding window mask are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, an error triggers after computing eval loss, but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")

return attention_mask
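
A minimal sketch (not part of the patch) of the broadcasting constraint the branch above guards against, assuming the [bsz, seq_len]-shaped attention mask noted in the comment: the sliding window mask comes back as [bsz, 1, tgt_len, tgt_len], so the addition only broadcasts when attention_mask.shape[0] == 1.

import torch

bsz, tgt_len = 4, 6
sliding = _make_sliding_window_causal_mask(
    bsz=bsz,
    tgt_len=tgt_len,
    dtype=torch.float32,
    device=torch.device("cpu"),
    past_key_values_length=0,
    sliding_window=3,
)  # shape: [4, 1, 6, 6]

ok = torch.ones(1, tgt_len) + sliding   # [1, 6] broadcasts against [4, 1, 6, 6]
print(ok.shape)                         # torch.Size([4, 1, 6, 6])

bad = torch.ones(bsz, tgt_len)          # [4, 6] does not broadcast against [4, 1, 6, 6]
# bad + sliding                         # raises a RuntimeError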


def flashattn_forward(
self,
self: OriginalMistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids
)

use_sliding_windows = (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
)

if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# Activate cache slicing only if the config has a `sliding_window` attribute
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window

past_key = past_key_value[0]
past_value = past_key_value[1]

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()

if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)

past_key_value = (past_key, past_value) if use_cache else None

if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
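
A minimal numeric sketch (not part of the patch) of the cache-slicing arithmetic above: once the cache is longer than the window, slicing from kv_seq_len - sliding_window leaves exactly sliding_window - 1 past positions, which is what the shape check expects before the current token's key/value are appended.

import torch

bsz, num_heads, head_dim = 1, 8, 64
sliding_window, kv_seq_len = 4, 6              # the cache holds kv_seq_len - 1 = 5 past positions
past_key = torch.randn(bsz, num_heads, kv_seq_len - 1, head_dim)

slicing_tokens = kv_seq_len - sliding_window   # the 2 oldest positions fall outside the window
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
print(past_key.shape)                          # torch.Size([1, 8, 3, 64]); 3 == sliding_window - 1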

@@ -120,7 +210,13 @@
qkv = rearrange(qkv, "b s ... -> (b s) ...")

output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
qkv,
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
@@ -157,6 +254,7 @@
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@
0.0,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
