Add shifted sparse attention #973

Merged: 17 commits, Jan 18, 2024

Changes from 12 commits
3 changes: 2 additions & 1 deletion README.md
@@ -828,7 +828,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:

# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
1 change: 1 addition & 0 deletions examples/code-llama/13b/lora.yml
@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
1 change: 1 addition & 0 deletions examples/code-llama/34b/lora.yml
@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
1 change: 1 addition & 0 deletions examples/code-llama/7b/lora.yml
@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
1 change: 1 addition & 0 deletions examples/llama-2/lora.yml
@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
1 change: 1 addition & 0 deletions examples/openllama-3b/lora.yml
@@ -52,6 +52,7 @@ logging_steps: 1
xformers_attention:
flash_attention: true
gptq_groupsize:
s2_attention:
gptq_model_v1:
warmup_steps: 20
evals_per_epoch: 4
141 changes: 140 additions & 1 deletion src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -70,11 +70,20 @@ def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
use_shifted_sparse_attn: Optional[bool] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if use_shifted_sparse_attn:
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward_with_s2attn
)
else:
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward
)

if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -213,6 +222,136 @@ def _prepare_decoder_attention_mask(
return attention_mask


# LongLoRA default: each attention group spans a quarter of the sequence length.
GROUP_SIZE_RATIO = 1 / 4


def flashattn_forward_with_s2attn(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel

From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py

attention_mask: [bsz, q_len]

`cu_seqlens` will be ignored if provided
`max_seqlen` will be ignored if provided
"""
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)

bsz, q_len, _ = hidden_states.size()

query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
# pylint: disable=duplicate-code

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)

# Past Key value support
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]

# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
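# The shifted and unshifted halves of the heads are folded into the batch
# dimension below (bsz -> bsz * 2), so the padding mask is repeated to give
# one row per folded sequence.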

key_padding_mask = attention_mask.repeat(2, 1)
nheads = qkv.shape[-2]
# shift

group_size = int(q_len * GROUP_SIZE_RATIO)
if q_len % group_size > 0:
raise ValueError(
f"q_len {q_len} should be divisible by group size {group_size}."
)

qkv = (
qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)
.permute(0, 3, 1, 2, 4, 5)
.reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
)
x = rearrange( # pylint: disable=invalid-name
qkv, "b s three h d -> b s (three h d)"
)
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
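# Build cumulative sequence lengths that carve each unpadded, head-folded
# sequence into group_size chunks, with the shifted copies of the heads given
# chunk boundaries offset by group_size // 2, so the varlen flash-attention
# call below only attends within groups.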
cu_q_len_tmp = torch.arange(
0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
)
cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(
bsz, 1
) + cu_q_lens[:-1].unsqueeze(-1)
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)

x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
),
"b s (h d) -> b s h d",
h=nheads // 2,
)
output = (
output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)
.transpose(1, 2)
.reshape(bsz, q_len, nheads, self.head_dim)
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value


def flashattn_forward(
self,
hidden_states: torch.Tensor,
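For readers new to LongLoRA's S²-Attn, the sketch below restates the idea behind flashattn_forward_with_s2attn in plain PyTorch: attention is computed only inside fixed-size groups, and half of the heads operate on a copy of the sequence shifted by half a group so information still crosses group boundaries. This is a hedged illustration, not the patched flash-attention path above: it uses scaled_dot_product_attention instead of flash_attn_varlen_qkvpacked_func, ignores padding and the wrap-around at the sequence ends, and the function name and toy shapes are made up.

# Illustrative reference for shifted-sparse (S^2) attention; not axolotl code.
import torch
import torch.nn.functional as F


def s2_attn_reference(q, k, v, group_size):
    """q, k, v: [bsz, n_heads, seq_len, head_dim]; seq_len must be divisible by group_size."""
    bsz, n_heads, seq_len, head_dim = q.shape
    half_heads = n_heads // 2
    shift = group_size // 2
    n_groups = seq_len // group_size

    # Shift the second half of the heads by half a group so their groups
    # straddle the group boundaries of the first half (wrap-around ignored).
    q, k, v = q.clone(), k.clone(), v.clone()
    for t in (q, k, v):
        t[:, half_heads:] = t[:, half_heads:].roll(-shift, dims=2)

    def to_groups(t):
        # [bsz, n_heads, seq_len, hd] -> [bsz * n_groups, n_heads, group_size, hd]
        return (
            t.reshape(bsz, n_heads, n_groups, group_size, head_dim)
            .permute(0, 2, 1, 3, 4)
            .reshape(bsz * n_groups, n_heads, group_size, head_dim)
        )

    # Ordinary causal attention, restricted to each group.
    out = F.scaled_dot_product_attention(
        to_groups(q), to_groups(k), to_groups(v), is_causal=True
    )
    out = (
        out.reshape(bsz, n_groups, n_heads, group_size, head_dim)
        .permute(0, 2, 1, 3, 4)
        .reshape(bsz, n_heads, seq_len, head_dim)
    )
    # Undo the shift for the second half of the heads.
    out[:, half_heads:] = out[:, half_heads:].roll(shift, dims=2)
    return out


# Toy usage: 1 sample, 8 heads, 2048 tokens, group_size = seq_len * GROUP_SIZE_RATIO = 512.
q = torch.randn(1, 8, 2048, 64)
out = s2_attn_reference(q, q, q, group_size=512)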
61 changes: 44 additions & 17 deletions src/axolotl/utils/models.py
@@ -256,31 +256,55 @@ def load_model(

replace_stablelm_attn_with_flash_attn(cfg.base_model)

if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
if cfg.sample_packing and cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
shifted-sparse attention does not current support sample packing."
winglian marked this conversation as resolved.
Show resolved Hide resolved
)

# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
Collaborator: Since it applies to llama models, do we need to account for mistral here as well?

Collaborator: @joecummings this should work with mistral and mixtral too, right?

if cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)

LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
if cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
Collaborator: Hm, does this mean FA won't be enabled for inference mode now?

Collaborator: I don't think FA was ever enabled for inference here. Here's the original code:

    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
        if cfg.device not in ["mps", "cpu"] and not inference:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

            LOG.info("patching with flash attention for sample packing")
            replace_llama_attn_with_flash_attn(
                packed=cfg.sample_packing,
                cross_entropy=cfg.flash_attn_cross_entropy,
                rms_norm=cfg.flash_attn_rms_norm,
            )

LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=True,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)

LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
hijack_llama_sdp_attention,
)

LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)

# Modify mistral derived models
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
@@ -387,9 +411,12 @@ def load_model(
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)

# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing:
if cfg.s2_attention:
pass
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
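Because the diff above interleaves removed and added lines, the attention-patch selection implied by the new load_model logic can be hard to read. The sketch below summarizes it under stated assumptions: the helper name and return strings are invented, and the device/inference guard around the sample-packing branch is omitted.

# Hypothetical helper, not part of axolotl; it only summarizes the branching shown above.
from types import SimpleNamespace


def describe_llama_attention_patch(cfg) -> str:
    """Approximate patch chosen by the load_model branching above (sketch only)."""
    if cfg.sample_packing and cfg.s2_attention:
        raise ValueError(
            "shifted-sparse attention does not currently support sample packing"
        )
    if not cfg.is_llama_derived_model:
        return "no llama-specific attention patch"
    if cfg.flash_attention:
        if cfg.sample_packing:
            # In the real code this branch is also gated on device and inference mode.
            return "flash attention patched for sample packing (packed=True)"
        if cfg.s2_attention:
            return "flash attention patched with shifted-sparse attention"
        return "flash attention without a llama monkeypatch from this block"
    if cfg.xformers_attention:
        return "xformers attention patch"
    if cfg.sdp_attention:
        return "scaled-dot-product attention patch"
    if cfg.s2_attention:
        raise NotImplementedError(
            "shifted-sparse attention is not implemented without flash attention"
        )
    return "no llama-specific attention patch"


cfg = SimpleNamespace(
    is_llama_derived_model=True,
    flash_attention=True,
    sample_packing=False,
    s2_attention=True,
    xformers_attention=False,
    sdp_attention=False,
)
print(describe_llama_attention_patch(cfg))  # flash attention patched with shifted-sparse attention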