
Commit

Format code according to linter
joecummings authored and committed Dec 19, 2023
1 parent 4f6acd7 commit 42a0645
Showing 2 changed files with 45 additions and 20 deletions.
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (42 changes: 29 additions & 13 deletions)
@@ -76,9 +76,13 @@ def replace_llama_attn_with_flash_attn(
         _prepare_decoder_attention_mask
     )
     if use_shifted_sparse_attn:
-        transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward_with_s2attn
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward_with_s2attn
+        )
     else:
-        transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward
+        )
 
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
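The hunk above swaps LlamaAttention.forward at the class level, so a single call to replace_llama_attn_with_flash_attn() reroutes every Llama attention layer in the process. Below is a minimal usage sketch, not part of this commit: the keyword arguments mirror the call site shown later in src/axolotl/utils/models.py, and the checkpoint name is only an example.

    # Illustrative only: apply the monkeypatch before loading a Llama model,
    # as axolotl's load_model() does.
    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        replace_llama_attn_with_flash_attn,
    )

    # use_shifted_sparse_attn=True routes forward() to flashattn_forward_with_s2attn;
    # leaving it False routes forward() to flashattn_forward.
    replace_llama_attn_with_flash_attn(
        packed=False,
        cross_entropy=False,
        rms_norm=False,
        use_shifted_sparse_attn=True,
    )

    model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")  # example checkpoint (assumption)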
@@ -218,7 +222,9 @@ def _prepare_decoder_attention_mask(
     return attention_mask
 
 
-group_size_ratio = 1/4
+GROUP_SIZE_RATIO = 1 / 4
+
+
 def flashattn_forward_with_s2attn(
     self,
     hidden_states: torch.Tensor,
@@ -301,18 +307,25 @@ def flashattn_forward_with_s2attn(
     nheads = qkv.shape[-2]
     # shift
 
-    group_size = int(q_len * group_size_ratio)
+    group_size = int(q_len * GROUP_SIZE_RATIO)
     if q_len % group_size > 0:
-        raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
+        raise ValueError(
+            f"q_len {q_len} should be divisible by group size {group_size}."
+        )
 
-    qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
-                                                                                                               q_len, 3,
-                                                                                                               self.num_heads // 2,
-                                                                                                               self.head_dim)
+    qkv = (
+        qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)
+        .permute(0, 3, 1, 2, 4, 5)
+        .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
+    )
     x = rearrange(qkv, "b s three h d -> b s (three h d)")
     x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
-    cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
-    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
+    cu_q_len_tmp = torch.arange(
+        0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
+    )
+    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(
+        bsz, 1
+    ) + cu_q_lens[:-1].unsqueeze(-1)
     cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
 
     x_unpad = rearrange(
@@ -328,8 +341,11 @@ def flashattn_forward_with_s2attn(
         "b s (h d) -> b s h d",
         h=nheads // 2,
     )
-    output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
-                                                                                               self.head_dim)
+    output = (
+        output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)
+        .transpose(1, 2)
+        .reshape(bsz, q_len, nheads, self.head_dim)
+    )
     return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
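For orientation, the group-size arithmetic reformatted above is what enforces the shifted-sparse ("S2") attention constraint: with GROUP_SIZE_RATIO = 1 / 4 the sequence is split into four groups, and half of the heads use windows offset by half a group (the group_size // 2 term in cu_q_len_tmp). A small worked sketch with illustrative numbers, not taken from the diff:

    GROUP_SIZE_RATIO = 1 / 4

    q_len = 4096                                # example sequence length (assumption)
    group_size = int(q_len * GROUP_SIZE_RATIO)  # 1024 tokens per attention group
    assert q_len % group_size == 0              # otherwise flashattn_forward_with_s2attn raises ValueError

    shift = group_size // 2                     # 512-token offset for the shifted half of the heads
    print(group_size, shift)                    # 1024 512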


src/axolotl/utils/models.py (23 changes: 16 additions & 7 deletions)
@@ -207,15 +207,18 @@ def load_model(
         replace_stablelm_attn_with_flash_attn(cfg.base_model)
 
     if cfg.sample_packing and cfg.s2_attention:
-        raise ValueError("Received `sample_packing=true` and `s2_attention=true`; however, \
-            shifted-sparse attention does not current support sample packing.")
+        raise ValueError(
+            "Received `sample_packing=true` and `s2_attention=true`; however, \
+            shifted-sparse attention does not current support sample packing."
+        )
 
     # Modify all llama derived models in one block
     if cfg.is_llama_derived_model:
         if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                replace_llama_attn_with_flash_attn,
-            )
+                replace_llama_attn_with_flash_attn,
+            )
+
             if cfg.sample_packing:
                 if cfg.device not in ["mps", "cpu"] and not inference:
                     LOG.info("patching with flash attention for sample packing")
@@ -230,16 +233,20 @@
                     packed=False,
                     cross_entropy=cfg.flash_attn_cross_entropy,
                     rms_norm=cfg.flash_attn_rms_norm,
-                    use_shifted_sparse_attn=True
+                    use_shifted_sparse_attn=True,
                 )
         elif cfg.xformers_attention:
             from axolotl.monkeypatch.llama_attn_hijack_xformers import (
                 hijack_llama_attention,
             )
+
             LOG.info("patching with xformers attention")
             hijack_llama_attention()
         elif cfg.sdp_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
+            from axolotl.monkeypatch.llama_attn_hijack_sdp import (
+                hijack_llama_sdp_attention,
+            )
+
             LOG.info("patching with sdp attention")
             hijack_llama_sdp_attention()
         elif cfg.landmark_attention:
@@ -254,7 +261,9 @@
             # Note: This might overwrite previous additional_special_tokens
             tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
         elif cfg.s2_attention:
-            raise NotImplementedError("Shifted-sparse attention not currently implemented without flash attention.")
+            raise NotImplementedError(
+                "Shifted-sparse attention not currently implemented without flash attention."
+            )
 
     if cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
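Taken together, the models.py hunks above gate the attention monkeypatches on mutually exclusive cfg flags. A condensed, hypothetical sketch of that routing follows: the function name is invented for illustration, the branch conditions are simplified (the device and inference checks are omitted), and only the branches visible in this diff are summarized.

    def select_llama_attention_patch(cfg):
        """Hypothetical summary of the attention-patch routing in load_model()."""
        # The checks below apply when cfg.is_llama_derived_model is true (per the diff).
        if cfg.sample_packing and cfg.s2_attention:
            raise ValueError("sample_packing and s2_attention cannot be combined")
        if cfg.flash_attention:
            if cfg.sample_packing:
                return "flash attention (packed)"          # replace_llama_attn_with_flash_attn(packed=True, ...)
            if cfg.s2_attention:
                return "flash + shifted-sparse attention"  # replace_llama_attn_with_flash_attn(use_shifted_sparse_attn=True)
            return "flash attention"                       # plain flash-attention path (not shown in these hunks)
        if cfg.xformers_attention:
            return "xformers attention"                    # hijack_llama_attention()
        if cfg.sdp_attention:
            return "sdp attention"                         # hijack_llama_sdp_attention()
        if cfg.landmark_attention:
            return "landmark attention"                    # patch + MEM_TOKEN special token
        if cfg.s2_attention:
            raise NotImplementedError(
                "Shifted-sparse attention not currently implemented without flash attention."
            )
        return "unpatched attention"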
