diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 56d590bad7d8..10a63abf2fb0 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -266,29 +266,29 @@ def get_rocm_test_compilation_infos( MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1), MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1), MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2), - MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1), - MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2), - MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1), - MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 2, 1, 1), - MMASchedule("MFMA_F16_16x16x16_F32", 2, 2, 1, 1, 1), - MMASchedule("MFMA_F16_16x16x16_F32", 2, 4, 2, 1, 2), - MMASchedule("MFMA_F16_16x16x16_F32", 4, 2, 4, 2, 2), - MMASchedule("MFMA_F16_32x32x8_F32", 1, 1, 1, 2, 2), - MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1), - MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2), - MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4), - MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 1, 1, 1, 1, 1), - MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 2, 2, 1, 1, 2), - MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 1, 4, 1, 1), - MMASchedule("MFMA_F8E4M3FNUZ_16x16x32_F32", 4, 2, 4, 2, 1), - MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1), - MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2), - MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1), - MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1), - MMASchedule("MFMA_I8_32x32x16_I32", 1, 1, 1, 1, 1), - MMASchedule("MFMA_I8_32x32x16_I32", 2, 2, 1, 1, 2), - MMASchedule("MFMA_I8_32x32x16_I32", 4, 1, 1, 2, 2), - MMASchedule("MFMA_I8_32x32x16_I32", 4, 2, 2, 2, 2), + MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 1), + MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 1, 2), + MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 1, 2, 1), + MMASchedule("MFMA_F32_16x16x16_F16", 1, 1, 2, 1, 1), + MMASchedule("MFMA_F32_16x16x16_F16", 2, 2, 1, 1, 1), + MMASchedule("MFMA_F32_16x16x16_F16", 2, 4, 2, 1, 2), + MMASchedule("MFMA_F32_16x16x16_F16", 4, 2, 4, 2, 2), + MMASchedule("MFMA_F32_32x32x8_F16", 1, 1, 1, 2, 2), + MMASchedule("MFMA_F32_32x32x8_F16", 2, 2, 1, 1, 1), + MMASchedule("MFMA_F32_32x32x8_F16", 1, 4, 2, 1, 2), + MMASchedule("MFMA_F32_32x32x8_F16", 4, 2, 1, 2, 4), + MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1), + MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 2, 2, 1, 1, 2), + MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1), + MMASchedule("MFMA_F32_16x16x32_F8E4M3FNUZ", 4, 2, 4, 2, 1), + MMASchedule("MFMA_I32_16x16x32_I8", 1, 1, 1, 1, 1), + MMASchedule("MFMA_I32_16x16x32_I8", 2, 2, 1, 1, 2), + MMASchedule("MFMA_I32_16x16x32_I8", 4, 1, 4, 1, 1), + MMASchedule("MFMA_I32_16x16x32_I8", 4, 2, 4, 2, 1), + MMASchedule("MFMA_I32_32x32x16_I8", 1, 1, 1, 1, 1), + MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2), + MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2), + MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2), ] elif intrinsic == "WMMA": schedules = [ @@ -319,22 +319,22 @@ def get_rocm_test_compilation_infos( wg_tile_m = schedule.m_count * schedule.m_tile_count * 16 wg_tile_n = schedule.n_count * schedule.n_tile_count * 16 wg_tile_k = schedule.k_tile_count * 4 - elif schedule.intrinsic == "MFMA_F16_16x16x16_F32": + elif schedule.intrinsic == "MFMA_F32_16x16x16_F16": wg_tile_m = schedule.m_count * schedule.m_tile_count * 16 wg_tile_n = schedule.n_count * schedule.n_tile_count * 16 wg_tile_k = schedule.k_tile_count * 16 - elif schedule.intrinsic == "MFMA_F16_32x32x8_F32": + elif schedule.intrinsic == "MFMA_F32_32x32x8_F16": wg_tile_m = schedule.m_count * schedule.m_tile_count * 32 wg_tile_n = schedule.n_count * schedule.n_tile_count * 32 wg_tile_k = schedule.k_tile_count * 8 elif ( - schedule.intrinsic == "MFMA_I8_16x16x32_I32" - or schedule.intrinsic == "MFMA_F8E4M3FNUZ_16x16x32_F32" + schedule.intrinsic == "MFMA_I32_16x16x32_I8" + or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ" ): wg_tile_m = schedule.m_count * schedule.m_tile_count * 16 wg_tile_n = schedule.n_count * schedule.n_tile_count * 16 wg_tile_k = schedule.k_tile_count * 32 - elif schedule.intrinsic == "MFMA_I8_32x32x16_I32": + elif schedule.intrinsic == "MFMA_I32_32x32x16_I8": wg_tile_m = schedule.m_count * schedule.m_tile_count * 32 wg_tile_n = schedule.n_count * schedule.n_tile_count * 32 wg_tile_k = schedule.k_tile_count * 16