[GPU][NFC] Follow the official convention to define mfma/wmma attributes (#18127)

The LLVM intrinsics and the official docs all use the
`[output_type]_MxNxK_[input_type]` format. This revision updates IREE's
definitions to follow that convention.
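
To make the convention concrete, here is how one renamed intrinsic reads (an illustrative decoding, using the `#iree_gpu.mma_layout` attribute syntax that appears in the tests below):

```mlir
// MFMA_F32_16x16x16_F16: f32 output/accumulator, MxNxK = 16x16x16, f16 inputs,
// i.e. C[16x16] (f32) += A[16x16] (f16) * B[16x16] (f16).
#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
```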

Some examples from official docs:

- https://gpuopen.com/learn/wmma_on_rdna3/
- https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme/
- https://github.com/ROCm/amd_matrix_instruction_calculator

This patch was generated by the commands below:

```bash
cd compiler
sed -i "s/MFMA_F16_16x16x16_F32/MFMA_F32_16x16x16_F16/g" **/*.mlir
sed -i "s/MFMA_F16_16x16x16_F32/MFMA_F32_16x16x16_F16/g" **/*.td
sed -i "s/MFMA_F16_16x16x16_F32/MFMA_F32_16x16x16_F16/g" **/*.cpp
sed -i "s/MFMA_F16_16x16x16_F32/MFMA_F32_16x16x16_F16/g" **/*.h
sed -i "s/MFMA_F16_32x32x8_F32/MFMA_F32_32x32x8_F16/g" **/*.mlir
sed -i "s/MFMA_F16_32x32x8_F32/MFMA_F32_32x32x8_F16/g" **/*.td
sed -i "s/MFMA_F16_32x32x8_F32/MFMA_F32_32x32x8_F16/g" **/*.h
sed -i "s/MFMA_F16_32x32x8_F32/MFMA_F32_32x32x8_F16/g" **/*.cpp
sed -i "s/MFMA_F8E4M3FNUZ_16x16x32_F32/MFMA_F32_16x16x32_F8E4M3FNUZ/g" **/*.mlir
sed -i "s/MFMA_F8E4M3FNUZ_16x16x32_F32/MFMA_F32_16x16x32_F8E4M3FNUZ/g" **/*.td
sed -i "s/MFMA_F8E4M3FNUZ_16x16x32_F32/MFMA_F32_16x16x32_F8E4M3FNUZ/g" **/*.h
sed -i "s/MFMA_F8E4M3FNUZ_16x16x32_F32/MFMA_F32_16x16x32_F8E4M3FNUZ/g" **/*.cpp
sed -i "s/MFMA_I8_16x16x32_I32/MFMA_I32_16x16x32_I8/g" **/*.mlir
sed -i "s/MFMA_I8_16x16x32_I32/MFMA_I32_16x16x32_I8/g" **/*.td
sed -i "s/MFMA_I8_16x16x32_I32/MFMA_I32_16x16x32_I8/g" **/*.h
sed -i "s/MFMA_I8_16x16x32_I32/MFMA_I32_16x16x32_I8/g" **/*.cpp
sed -i "s/MFMA_I8_32x32x16_I32/MFMA_I32_32x32x16_I8/g" **/*.mlir
sed -i "s/MFMA_I8_32x32x16_I32/MFMA_I32_32x32x16_I8/g" **/*.td
sed -i "s/MFMA_I8_32x32x16_I32/MFMA_I32_32x32x16_I8/g" **/*.h
sed -i "s/MFMA_I8_32x32x16_I32/MFMA_I32_32x32x16_I8/g" **/*.cpp
```
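
Note: the `**` globs above recurse into subdirectories only when bash's `globstar` option is enabled (`shopt -s globstar`). A minimal equivalent sketch for one of the renames, assuming GNU sed's in-place `-i`:

```bash
# Recursively apply one of the renames without relying on globstar
# (assumes GNU sed for in-place editing).
find . \( -name '*.mlir' -o -name '*.td' -o -name '*.cpp' -o -name '*.h' \) \
  -exec sed -i 's/MFMA_F16_16x16x16_F32/MFMA_F32_16x16x16_F16/g' {} +
```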

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
hanhanW authored Aug 6, 2024
1 parent f109f66 commit 82012e6
Showing 30 changed files with 174 additions and 174 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
-// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
+// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
@@ -51,7 +51,7 @@ func.func @contract_to_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<8x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
@@ -128,7 +128,7 @@ func.func @contract_to_mfma_16x16x16_mm(%a : vector<16x16xf16>, %b : vector<16x1
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>

%O = iree_vector_ext.to_layout %output to #layout_b : vector<16x16xf32>
@@ -216,7 +216,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch(%a : vector<64x8xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x32xf16> into vector<64x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x32xf32>
@@ -305,7 +305,7 @@ func.func @contract_to_mfma_32x32x8_mm_kbatch(%a : vector<32x16xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
@@ -388,7 +388,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch_order(%a : vector<64x8xf16>, %b :
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x96xf16> into vector<64x96xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x96xf32>
@@ -479,7 +479,7 @@ func.func @contract_to_mfma_32x32x8_mmt(%a : vector<32x8xf16>, %b : vector<64x8x
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<64x8xf16> into vector<32x64xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x64xf32>
@@ -242,7 +242,7 @@ func.func @weight_dequant_matmul() {
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
-func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
94 changes: 47 additions & 47 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -215,19 +215,19 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
@@ -277,7 +277,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -295,7 +295,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
// #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
// #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
@@ -316,8 +316,8 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -334,7 +334,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -437,26 +437,26 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({16}, getCType());
@@ -485,11 +485,11 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32:
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
@@ -502,11 +502,11 @@ int64_t MMAAttr::getBlockSize() const {
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return 64;
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
@@ -524,20 +524,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 1}};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 8}};
}
@@ -556,20 +556,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{8, 1}};
}
@@ -585,14 +585,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
@@ -632,11 +632,11 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
rhs, acc)
.getResult();
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
@@ -716,8 +716,8 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
switch (getIntrinsic().getValue()) {
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
break;
default:
return failure();
20 changes: 10 additions & 10 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -100,23 +100,23 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
-def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
-def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
-def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
-def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
-def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
+def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>;
+def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>;
+def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>;
+def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>;
+def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
-MFMA_F16_16x16x16_F32,
-MFMA_F16_32x32x8_F32,
-MFMA_F8E4M3FNUZ_16x16x32_F32,
-MFMA_I8_16x16x32_I32,
-MFMA_I8_32x32x16_I32,
+MFMA_F32_16x16x16_F16,
+MFMA_F32_32x32x8_F16,
+MFMA_F32_16x16x32_F8E4M3FNUZ,
+MFMA_I32_16x16x32_I8,
+MFMA_I32_32x32x16_I8,
WMMA_F16_16x16x16_F32,
WMMA_F16_16x16x16_F16
]>;
@@ -56,7 +56,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
@@ -99,7 +99,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<4xf16>, vector<4xf16> into vector<4xf32>
@@ -127,7 +127,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
rhs_permutation = [1, 0]
}
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
@@ -2,21 +2,21 @@

module {
func.func @test_mfma_f16_16x16x16_f32() attributes {
-mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} {
+mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_16x16x16_f32
-// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>

module {
func.func @test_mfma_f16_32x32x8_f32() attributes {
-mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>} {
+mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_32x32x8_f32
-// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>

module {
func.func @test_wmma_f16_16x16x16_f32() attributes {