adds llama and mistral dropout support (#858)
* adds llama and mistral dropout support

* gracefully handle attention dropout if not available yet
winglian committed Nov 15, 2023
1 parent 1470650 commit db8a8af
Showing 2 changed files with 17 additions and 6 deletions.
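Both patches compute the dropout rate the same way before any flash-attn call: during training it is read from the attention module's attention_dropout attribute, otherwise it is forced to 0.0, and getattr with a default keeps things working on transformers versions whose attention modules do not expose that attribute yet. The snippet below is a minimal sketch of that pattern in isolation, not code from this commit; the ToyAttention class and its constructor are illustrative only and assume torch is installed.

# Sketch of the fallback pattern introduced by this commit; the class name
# and constructor are hypothetical, not part of axolotl.
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, attention_dropout=None):
        super().__init__()
        if attention_dropout is not None:
            # newer attention modules carry this attribute; older ones do not
            self.attention_dropout = attention_dropout

    def dropout_rate(self):
        # same expression the patched forwards use
        return 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)


attn = ToyAttention(attention_dropout=0.1)
attn.train()
print(attn.dropout_rate())  # 0.1 while training

attn.eval()
print(attn.dropout_rate())  # 0.0 at inference time

legacy = ToyAttention()  # attribute missing, as on older transformers versions
legacy.train()
print(legacy.dropout_rate())  # 0.0, handled gracefully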
14 changes: 11 additions & 3 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -321,6 +321,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -366,6 +373,7 @@ def flashattn_forward(
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                causal=is_causal,
            )
        else:
@@ -398,7 +406,7 @@ def flashattn_forward(
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
            )
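For reference, this is roughly how the new dropout_p argument behaves at the flash-attn call sites patched above. The snippet is a sketch only, not part of the commit; it assumes flash-attn 2.x is installed, a CUDA device is available, and uses made-up tensor shapes.

# Sketch: dropout_p passed to flash_attn_kvpacked_func, mirroring the patched
# llama forward; nonzero only when the model is training.
import torch
from flash_attn import flash_attn_kvpacked_func

bsz, seqlen, n_heads, head_dim = 2, 128, 8, 64
q = torch.randn(bsz, seqlen, n_heads, head_dim, device="cuda", dtype=torch.float16)
kv = torch.randn(bsz, seqlen, 2, n_heads, head_dim, device="cuda", dtype=torch.float16)

out_train = flash_attn_kvpacked_func(q, kv, dropout_p=0.1, causal=True)  # training step
out_eval = flash_attn_kvpacked_func(q, kv, dropout_p=0.0, causal=True)   # eval, deterministic
print(out_train.shape, out_eval.shape)  # both torch.Size([2, 128, 8, 64])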
9 changes: 6 additions & 3 deletions src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -201,6 +201,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -213,7 +215,7 @@ def flashattn_forward(
             qkv,
             cu_seqlens,
             max_seqlen,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=True,
             window_size=window_size,
@@ -239,7 +241,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -253,6 +255,7 @@ def flashattn_forward(
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                causal=is_causal,
                window_size=window_size,
            )
@@ -286,7 +289,7 @@ def flashattn_forward(
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
                window_size=window_size,
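The mistral patch is the same change with one extra detail: every call already carries a window_size tuple for sliding-window attention, and dropout_p is simply threaded through alongside it. A short sketch under the same assumptions as above (a flash-attn build new enough to accept window_size, CUDA, illustrative shapes, and a made-up 4096-token window):

# Sketch: dropout combined with sliding-window attention, mirroring the
# argument set of the patched mistral forward.
import torch
from flash_attn import flash_attn_kvpacked_func

q = torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
kv = torch.randn(1, 8192, 2, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.1,             # nonzero only while model.training is True
    softmax_scale=None,
    causal=True,
    window_size=(4096, 4096),  # sliding window; (-1, -1) disables it
)
print(out.shape)  # torch.Size([1, 8192, 8, 64])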
