From a04c262e54816a3e7d34e76df76510ce6324803e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Wed, 15 Feb 2023 23:10:51 -0500
Subject: [PATCH] [spirv] Vectorize integer extend ops in lowering to
 subgroup_mma (#12202)

For integer inputs, extend ops are matched against neighboring
vector.transfer_read/vector.contract ops when lowering to subgroup MMA ops.
This enables vectorizing the extend ops to cooperative matrix sizes and also
adds support for cases with mixed signedness.

Depends on https://reviews.llvm.org/D143922
---
 .../SPIRVTileAndVectorizeToCooperativeOps.cpp |  33 +++-
 ...tile_and_vectorize_to_cooperative_ops.mlir | 161 ++++++++++++++++++
 2 files changed, 193 insertions(+), 1 deletion(-)
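For context, after the matmul body is vectorized but before unrolling to
cooperative matrix sizes, the IR this change has to handle looks roughly like
the hand-written sketch below (the function name, tile shapes, and value names
are illustrative assumptions, not taken from this patch):

#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
#map_c = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @mixed_signedness_tile_sketch(%lhs: memref<32x32xi8>, %rhs: memref<32x32xi8>,
                                        %acc: vector<32x32xi32>) -> vector<32x32xi32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %a = vector.transfer_read %lhs[%c0, %c0], %pad {in_bounds = [true, true]} : memref<32x32xi8>, vector<32x32xi8>
  %b = vector.transfer_read %rhs[%c0, %c0], %pad {in_bounds = [true, true]} : memref<32x32xi8>, vector<32x32xi8>
  // The extend ops sit between the transfer_reads and the contraction; the
  // unrolling logic now derives their native vector shape from these
  // neighbors instead of bailing out.
  %a_ext = arith.extui %a : vector<32x32xi8> to vector<32x32xi32>
  %b_ext = arith.extsi %b : vector<32x32xi8> to vector<32x32xi32>
  %res = vector.contract {indexing_maps = [#map_a, #map_b, #map_c],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
      %a_ext, %b_ext, %acc : vector<32x32xi32>, vector<32x32xi32> into vector<32x32xi32>
  return %res : vector<32x32xi32>
}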
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 83b25efdf559..0ef3918ecbef 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -148,6 +148,25 @@ void populateVectorizationPatterns(MLIRContext *context,
   vector::populateVectorReductionToContractPatterns(patterns);
 }
 
+template <typename ExtOpTy>
+Optional<SmallVector<int64_t>> getExtOpVectorShape(
+    ExtOpTy op, ArrayRef<int64_t> nativeShape) {
+  auto insert =
+      op.getOperand().template getDefiningOp<vector::InsertStridedSliceOp>();
+  if (!insert) return std::nullopt;
+
+  VectorType sliceType = insert.getSourceVectorType();
+  for (Operation *users : op->getUsers()) {
+    auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
+    if (!extract) return std::nullopt;
+    auto vecType = extract.getResult().getType().cast<VectorType>();
+    if (!llvm::equal(sliceType.getShape(), vecType.getShape()))
+      return std::nullopt;
+  }
+
+  return llvm::to_vector<>(sliceType.getShape());
+}
+
 /// Returns vector shape matching native cooperative op sizes for unrolling
 /// high-D vectors.
 Optional<SmallVector<int64_t>> getCooperativeOpVectorShape(
@@ -186,8 +205,15 @@ Optional<SmallVector<int64_t>> getCooperativeOpVectorShape(
   }
 
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+    auto sourceOp = op;
+    if (op->hasOneUse()) {
+      auto user = *op->user_begin();
+      if (isa<arith::ExtSIOp>(user) || isa<arith::ExtUIOp>(user))
+        sourceOp = user;
+    }
+
     VectorType sliceType;
-    for (Operation *users : op->getUsers()) {
+    for (Operation *users : sourceOp->getUsers()) {
       auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
       if (!extract) return std::nullopt;
       auto vecType = extract.getResult().getType().cast<VectorType>();
@@ -197,6 +223,11 @@ Optional<SmallVector<int64_t>> getCooperativeOpVectorShape(
     return llvm::to_vector<>(sliceType.getShape());
   }
 
+  if (auto extOp = dyn_cast<arith::ExtSIOp>(op))
+    return getExtOpVectorShape(extOp, nativeShape);
+  if (auto extOp = dyn_cast<arith::ExtUIOp>(op))
+    return getExtOpVectorShape(extOp, nativeShape);
+
   return std::nullopt;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index bfda15057d4e..54f3b700c47e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -341,3 +341,164 @@ hal.executable public @matmul_256x1024x128_div_add {
 // CHECK: %[[READ6:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
 // CHECK: %[[DIV:.+]] = arith.divf %[[READ6]], %[[READ5]] : vector<1x16x16xf16>
 // CHECK: vector.transfer_write %[[DIV]], %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [16, 16], [0, 0, 32]]>
+#translation = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>,
+    #hal.descriptor_set.binding<4, storage_buffer>
+  ]>
+]>
+hal.executable public @matmul_256x1024x128_mixed_signedness_int8 {
+  hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+    spirv.target_env = #spirv.target_env<
+      #spirv.vce<v1.6,
+        [Shader, Float16, Int8, StorageBuffer16BitAccess, CooperativeMatrixNV],
+        [SPV_KHR_16bit_storage, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>,
+      NVIDIA:DiscreteGPU,
+      #spirv.resource_limits<
+        cooperative_matrix_properties_nv = [
+          #spirv.coop_matrix_props<
+            a_type = i8, b_type = i8, c_type = i32, k_size = 32,
+            m_size = 8, n_size = 8, result_type = i32, scope = <Subgroup>>,
+          #spirv.coop_matrix_props<
+            a_type = f16, b_type = f16, c_type = f16, k_size = 16,
+            m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>,
+          #spirv.coop_matrix_props<
+            a_type = f16, b_type = f16, c_type = f32, k_size = 16,
+            m_size = 16, n_size = 16, result_type = f32, scope = <Subgroup>>
+        ],
+        max_compute_shared_memory_size = 49152,
+        max_compute_workgroup_invocations = 1024,
+        max_compute_workgroup_size = [2147483647, 65535, 65535],
+        subgroup_size = 32>
+    >}> {
+    hal.executable.export public @matmul_256x1024x128_mixed_signedness_int8 layout(#pipeline_layout) attributes {
+      translation_info = #translation,
+      workgroup_size = [32 : index, 1 : index, 1 : index]
+    } {
+    ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):  // no predecessors
+      %c1 = arith.constant 1 : index
+      %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg0]
+      %1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%arg1]
+      hal.return %0, %1, %c1 : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_256x1024x128_mixed_signedness_int8() {
+        %cst = arith.constant 0 : i32
+        %cst_i8 = arith.constant 0 : i8
+        %c0 = arith.constant 0 : index
+        %c32 = arith.constant 32 : index
+        %c1024 = arith.constant 1024 : index
+        %0 = gpu.thread_id x
+        %1 = gpu.thread_id y
+        %2 = gpu.thread_id z
+        %alloc = memref.alloc() : memref<32x32xi8, 3>
+        %alloc_0 = memref.alloc() : memref<32x32xi8, 3>
+        %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<256x1024xi8>
+        %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1024x128xi8>
+        %7 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<256x128xi32>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+        %9 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
+        %subview = memref.subview %7[%8, %9] [32, 32] [1, 1] : memref<256x128xi32> to memref<32x32xi32, strided<[128, 1], offset: ?>>
+        %subview_1 = memref.subview %3[%8, 0] [32, 1024] [1, 1] : memref<256x1024xi8> to memref<32x1024xi8, strided<[1024, 1], offset: ?>>
+        %subview_2 = memref.subview %4[0, %9] [1024, 32] [1, 1] : memref<1024x128xi8> to memref<1024x32xi8, strided<[128, 1], offset: ?>>
+        linalg.fill {__internal_linalg_transform__ = "workgroup_memory"} ins(%cst : i32) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
+        scf.for %arg0 = %c0 to %c1024 step %c32 {
+          %subview_5 = memref.subview %subview_1[0, %arg0] [32, 32] [1, 1] : memref<32x1024xi8, strided<[1024, 1], offset: ?>> to memref<32x32xi8, strided<[1024, 1], offset: ?>>
+          %subview_6 = memref.subview %subview_2[%arg0, 0] [32, 32] [1, 1] : memref<1024x32xi8, strided<[128, 1], offset: ?>> to memref<32x32xi8, strided<[128, 1], offset: ?>>
+          gpu.barrier
+          %subview_7 = memref.subview %alloc[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
+          %10 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+          %11 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+          %12 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+          %13 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+          %subview_8 = memref.subview %subview_5[%10, %11] [1, 8] [1, 1] : memref<32x32xi8, strided<[1024, 1], offset: ?>> to memref<1x8xi8, strided<[1024, 1], offset: ?>>
+          %subview_9 = memref.subview %subview_7[%12, %13] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+          %14 = vector.transfer_read %subview_8[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[1024, 1], offset: ?>>, vector<1x8xi8>
+          vector.transfer_write %14, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+          %subview_10 = memref.subview %alloc_0[%c0, %c0] [32, 32] [1, 1] : memref<32x32xi8, 3> to memref<32x32xi8, strided<[32, 1], offset: ?>, 3>
+          %15 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+          %16 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+          %17 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+          %18 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+          %subview_11 = memref.subview %subview_6[%15, %16] [1, 8] [1, 1] : memref<32x32xi8, strided<[128, 1], offset: ?>> to memref<1x8xi8, strided<[128, 1], offset: ?>>
+          %subview_12 = memref.subview %subview_10[%17, %18] [1, 8] [1, 1] : memref<32x32xi8, strided<[32, 1], offset: ?>, 3> to memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+          %19 = vector.transfer_read %subview_11[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<1x8xi8, strided<[128, 1], offset: ?>>, vector<1x8xi8>
+          vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xi8>, memref<1x8xi8, strided<[32, 1], offset: ?>, 3>
+          gpu.barrier
+          linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel", "reduction"]
+          }
+          ins(%alloc, %alloc_0 : memref<32x32xi8, 3>, memref<32x32xi8, 3>) outs(%subview : memref<32x32xi32, strided<[128, 1], offset: ?>>)
+          attrs = {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #config} {
+          ^bb0(%in: i8, %in_5: i8, %out: i32):
+            %20 = arith.extui %in : i8 to i32
+            %21 = arith.extsi %in_5 : i8 to i32
+            %22 = arith.muli %20, %21 : i32
+            %23 = arith.addi %22, %out : i32
+            linalg.yield %23 : i32
+          }
+        }
+        return
+      }
+    }
+  }
+}
+
+// CHECK: #[[$MAP_Y:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK: #[[$MAP_X:.+]] = affine_map<()[s0] -> ((s0 floordiv 32) * 16)>
+
+// CHECK-LABEL: func.func @matmul_256x1024x128_mixed_signedness_int8()
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<16x16xi32>
+
+// CHECK-DAG: %[[ID_X:.+]] = gpu.thread_id x
+// CHECK-DAG: %[[ID_Y:.+]] = gpu.thread_id y
+
+// CHECK-DAG: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<32x32xi8, 3>
+// CHECK-DAG: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x32xi8, 3>
+
+// CHECK: %[[OFFSET_Y:.+]] = affine.apply #[[$MAP_Y]]()[%[[ID_Y]]]
+// CHECK: %[[OFFSET_X:.+]] = affine.apply #[[$MAP_X]]()[%[[ID_X]]]
+
+// CHECK: scf.for %{{.+}} = %[[OFFSET_Y]] to %[[C32]] step %[[C32]]
+// CHECK:   scf.for %{{.+}} = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
+// CHECK:     vector.transfer_write %[[ZERO]], {{.+}} : vector<16x16xi32>, memref<16x16xi32, strided<[128, 1], offset: ?>>
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C1024]] step %[[C32]]
+// CHECK:   gpu.barrier
+// CHECK:   vector.transfer_read {{.+}} vector<1x8xi8>
+// CHECK:   vector.transfer_write
+// CHECK:   vector.transfer_read {{.+}} vector<1x8xi8>
+// CHECK:   vector.transfer_write
+// CHECK:   gpu.barrier
+// CHECK:   scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]]
+// CHECK:     %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Y]], 0]
+// CHECK:     scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]]
+// CHECK:       %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]]
+// CHECK:       %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]]
+// CHECK:       %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]]
+// CHECK:       %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]]
+// CHECK:       %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]]
+// CHECK:       %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]]
+// CHECK:       %[[EXTUI0:.+]] = arith.extui %[[READ0]] : vector<16x16xi8> to vector<16x16xi32>
+// CHECK:       %[[EXTUI1:.+]] = arith.extui %[[READ1]] : vector<16x16xi8> to vector<16x16xi32>
+// CHECK:       %[[EXTSI0:.+]] = arith.extsi %[[READ2]] : vector<16x16xi8> to vector<16x16xi32>
+// CHECK:       %[[EXTSI1:.+]] = arith.extsi %[[READ3]] : vector<16x16xi8> to vector<16x16xi32>
+// CHECK:       %[[CT0:.+]] = vector.contract
+// CHECK-SAME:    %[[EXTUI0]], %[[EXTSI0]], %[[READ4]] : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+// CHECK:       %[[CT1:.+]] = vector.contract
+// CHECK-SAME:    %[[EXTUI1]], %[[EXTSI1]], %[[CT0]] : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+// CHECK:       vector.transfer_write %[[CT1]], %{{.+}}[%[[C0]], %[[C0]]]
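For reference, the unrolled inner tile that the CHECK lines above describe
corresponds roughly to the following hand-written MLIR (not compiler output;
the function name, standalone memref types, and value names are illustrative
assumptions):

#map_lhs = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_rhs = affine_map<(d0, d1, d2) -> (d2, d1)>
#map_out = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @unrolled_tile_sketch(%lhs: memref<16x32xi8>, %rhs: memref<32x16xi8>,
                                %acc: memref<16x16xi32>) {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %pad_i8 = arith.constant 0 : i8
  %pad_i32 = arith.constant 0 : i32
  // Reads are already unrolled to the 16x16 cooperative matrix shape.
  %read0 = vector.transfer_read %lhs[%c0, %c0], %pad_i8 {in_bounds = [true, true]} : memref<16x32xi8>, vector<16x16xi8>
  %read1 = vector.transfer_read %lhs[%c0, %c16], %pad_i8 {in_bounds = [true, true]} : memref<16x32xi8>, vector<16x16xi8>
  %read2 = vector.transfer_read %rhs[%c0, %c0], %pad_i8 {in_bounds = [true, true]} : memref<32x16xi8>, vector<16x16xi8>
  %read3 = vector.transfer_read %rhs[%c16, %c0], %pad_i8 {in_bounds = [true, true]} : memref<32x16xi8>, vector<16x16xi8>
  %read4 = vector.transfer_read %acc[%c0, %c0], %pad_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
  // The extend ops now operate on cooperative-matrix-sized vectors, with
  // unsigned extension on the LHS and signed extension on the RHS.
  %extui0 = arith.extui %read0 : vector<16x16xi8> to vector<16x16xi32>
  %extui1 = arith.extui %read1 : vector<16x16xi8> to vector<16x16xi32>
  %extsi0 = arith.extsi %read2 : vector<16x16xi8> to vector<16x16xi32>
  %extsi1 = arith.extsi %read3 : vector<16x16xi8> to vector<16x16xi32>
  // Two contractions chain the accumulator across the two K slices.
  %ct0 = vector.contract {indexing_maps = [#map_lhs, #map_rhs, #map_out],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
      %extui0, %extsi0, %read4 : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
  %ct1 = vector.contract {indexing_maps = [#map_lhs, #map_rhs, #map_out],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
      %extui1, %extsi1, %ct0 : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
  vector.transfer_write %ct1, %acc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
  return
}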