
Commit 0b04d83

scale the attn based on the configuration settings
winglian committed Sep 15, 2023
1 parent f512bda · commit 0b04d83
Showing 1 changed file with 3 additions and 1 deletion.
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py (3 additions, 1 deletion)
@@ -36,7 +36,9 @@ def flashattn_attn(
     head_mask: Optional[torch.Tensor] = None,
     position_bias: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-    softmax_scale = 1 / (key.size(-1) ** 0.5) if self.scale_attn_weights else None
+    softmax_scale = (
+        1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
+    )
 
     query = query.permute(0, 2, 1, 3)
     key = key.permute(0, 2, 1, 3)
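For context, a minimal sketch (not part of the commit) of what the configurable exponent changes, assuming the BTLM convention where muP-trained models use 1/d attention scaling (attn_scale_power = 1.0) while standard transformers use 1/sqrt(d) (attn_scale_power = 0.5); hardcoding ** 0.5 would mis-scale the muP case. The head_dim value below is purely illustrative:

# Illustrative sketch: how attn_scale_power changes the softmax scale.
# head_dim = 64 is a hypothetical value, not taken from the commit.
head_dim = 64

# Standard transformer scaling, 1 / sqrt(d): attn_scale_power = 0.5
standard_scale = 1 / (head_dim ** 0.5)   # 0.125

# muP-style scaling (as used by BTLM), 1 / d: attn_scale_power = 1.0
mup_scale = 1 / (head_dim ** 1.0)        # 0.015625

print(f"1/sqrt(d) = {standard_scale}, 1/d = {mup_scale}")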
