
Commit 2915bf0
Include AMD in flash attn check
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
Maxusmusti committed Oct 14, 2024
1 parent 8e3f553 commit 2915bf0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/instructlab/training/utils.py
@@ -151,7 +151,9 @@ def supports_flash_attention(device_id=0):
     # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0
-    return is_sm8x or is_sm90
+    dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
+    is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
+    return is_sm8x or is_sm90 or is_compat_amd


 def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
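
For context, a minimal sketch of how the whole supports_flash_attention helper might read after this change. The Ampere/Hopper check and the three added lines are taken from the diff above; the surrounding body (the torch import and the torch.cuda.get_device_capability call that produces major/minor) is assumed, since it sits outside the hunk shown in this commit.

import torch


def supports_flash_attention(device_id=0):
    """Sketch: return True if the GPU at device_id can run flash attention.

    Only the lines inside the diff hunk are verbatim from the commit; the
    capability lookup above them is an assumption about the surrounding code.
    """
    major, minor = torch.cuda.get_device_capability(device_id)

    # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    # On ROCm builds of PyTorch, gcnArchName reports strings like
    # "gfx90a:sramecc+:xnack-"; keep only the architecture token before the
    # first colon and accept the MI200/MI300-class (CDNA2/CDNA3) parts.
    dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
    is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")

    return is_sm8x or is_sm90 or is_compat_amd

Note that gcnArchName is a ROCm-oriented device property; depending on the PyTorch build, callers may want to guard that lookup, though the commit as shown calls it unconditionally.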
