diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index ef69b1a..b366cba 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -151,7 +151,9 @@ def supports_flash_attention(device_id=0):
     # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0
-    return is_sm8x or is_sm90
+    dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
+    is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
+    return is_sm8x or is_sm90 or is_compat_amd
 
 
 def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
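
For reference, a minimal self-contained sketch of what the updated check amounts to is shown below. The major/minor lookup via torch.cuda.get_device_capability and the torch.version.hip guard around the ROCm-only gcnArchName attribute are assumptions about the surrounding function, not part of this patch.

# Sketch only: reconstructs supports_flash_attention() around the patched lines.
import torch


def supports_flash_attention(device_id: int = 0) -> bool:
    """Best-effort check for Flash Attention support on NVIDIA or AMD GPUs."""
    if not torch.cuda.is_available():
        return False

    # Assumed surrounding code: compute capability of the target device.
    major, minor = torch.cuda.get_device_capability(device_id)

    # NVIDIA: Ampere (SM 8.x) or newer (SM 9.0).
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    # AMD: CDNA2/CDNA3 parts (MI200/MI300 series) report gcnArchName strings
    # such as "gfx90a:sramecc+:xnack-", so compare only the part before ":".
    is_compat_amd = False
    if torch.version.hip is not None:  # assumed guard: only query gcnArchName on ROCm builds
        dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
        is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")

    return is_sm8x or is_sm90 or is_compat_amd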