From 914858fb89c028a94564b590626aab519d611e54 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 20 Sep 2024 14:51:21 +0530 Subject: [PATCH] [VectorDistribution] Reuse intrinsic layout in chained gemm (#18505) This patch teaches attention codegen pipeline to reuse the intrinsic layout of output of the first matmul as the lhs of the second matmul. This is possible for 16x16x16 and 32x32x8 MFMA intrinsic layouts. --- .../compiler/Codegen/Common/GPU/BUILD.bazel | 1 + .../Codegen/Common/GPU/CMakeLists.txt | 1 + .../Common/GPU/GPUDistributionPatterns.cpp | 14 +- .../GPUNestedLayoutDistributionPatterns.cpp | 89 +++++++ .../LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp | 228 +++++++++++++++--- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 1 - .../pipeline_vector_distribute_gfx940.mlir | 76 +++++- 7 files changed, 364 insertions(+), 46 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 55e1b247dd39..296b316ce79c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -104,6 +104,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils", "//compiler/src/iree/compiler/Dialect/Encoding/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index 4ded89f40fa0..e078969c7791 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -135,6 +135,7 @@ iree_cc_library( iree::compiler::Codegen::Utils::VectorOpUtils iree::compiler::Dialect::Encoding::IR iree::compiler::Dialect::HAL::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index ad0dbc2c2ee4..58ed23ac1fbf 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -1020,10 +1020,16 @@ struct DistributeTrivialLayoutConversions final PatternRewriter &rewriter) const override { auto input = cast(toLayoutOp.getInput()); auto output = cast(toLayoutOp.getOutput()); - VectorLayoutInterface currentLayout = - dyn_cast(signature[input]); - VectorLayoutInterface targetLayout = - dyn_cast(signature[output]); + VectorLayoutInterface currentLayout = signature[input]; + VectorLayoutInterface targetLayout = signature[output]; + + if (!currentLayout) { + return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on input"); + } + + if (!targetLayout) { + return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on output"); + } if (currentLayout != targetLayout) { return rewriter.notifyMatchFailure(toLayoutOp, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index 260f7c24b07c..ec86bf2cf058 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -10,6 +10,7 @@ 
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h" #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Utils/Permutation.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" @@ -601,6 +602,93 @@ struct DistributeTranspose final : OpDistributionPattern { } }; +struct DistributeBatchOuterToLayoutConversions final + : OpDistributionPattern { + using OpDistributionPattern::OpDistributionPattern; + + LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp, + DistributionSignature &signature, + PatternRewriter &rewriter) const override { + Location loc = toLayoutOp.getLoc(); + auto input = cast(toLayoutOp.getInput()); + auto output = cast(toLayoutOp.getOutput()); + auto layoutA = dyn_cast(signature[input]); + auto layoutB = dyn_cast(signature[output]); + + if (!layoutA || !layoutB) { + return rewriter.notifyMatchFailure(toLayoutOp, "non-nested layout"); + } + + // Check if everything other than batch and outer tile matches. + if (layoutA.getSubgroupTile() != layoutB.getSubgroupTile()) { + return failure(); + } + if (layoutA.getSubgroupStrides() != layoutB.getSubgroupStrides()) { + return failure(); + } + if (layoutA.getThreadTile() != layoutB.getThreadTile()) { + return failure(); + } + if (layoutA.getThreadStrides() != layoutB.getThreadStrides()) { + return failure(); + } + if (layoutA.getElementTile() != layoutB.getElementTile()) { + return failure(); + } + + auto batchTileA = SmallVector(layoutA.getBatchTile()); + auto outerTileA = SmallVector(layoutA.getOuterTile()); + auto batchTileB = SmallVector(layoutB.getBatchTile()); + auto outerTileB = SmallVector(layoutB.getOuterTile()); + + // Check if there is a batch/outer tile mismatch. + if (batchTileA == batchTileB && outerTileA == outerTileB) { + return rewriter.notifyMatchFailure(toLayoutOp, + "trivial layout conversion"); + } + + SmallVector shapeA = layoutA.getDistributedShape(); + SmallVector shapeB = layoutB.getDistributedShape(); + int64_t rank = layoutA.getRank(); + + // Interleave batch and outer dims by transposing. + + // Build a permutation for interleaving. + auto interleavePermutation = + llvm::to_vector(llvm::seq(shapeA.size())); + for (int i = 0; i < rank; ++i) { + // Batch tile : [0...rank] + // OuterTile : [rank+1...2*rank] + // Interleave : [batch0, outer0, batch1, outer1,...] + interleavePermutation[2 * i] = i; + interleavePermutation[2 * i + 1] = i + rank; + } + + auto interleaved = rewriter.create( + loc, getDistributed(rewriter, input, layoutA), interleavePermutation); + + // Shape cast to match the new layout. + + SmallVector transposedShapeB(shapeB); + applyPermutationToVector(transposedShapeB, interleavePermutation); + Type reshapedType = VectorType::get( + transposedShapeB, interleaved.getResultVectorType().getElementType()); + + auto reshaped = + rewriter.create(loc, reshapedType, interleaved); + + // Inverse transpose to preserve original order. 
+    SmallVector<int64_t> invertedPermutation =
+        invertPermutationVector(interleavePermutation);
+
+    auto layouted = rewriter.create<vector::TransposeOp>(loc, reshaped,
+                                                         invertedPermutation);
+
+    replaceOpWithDistributedValues(rewriter, toLayoutOp, layouted.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
@@ -612,6 +700,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
   patterns.add(patterns.getContext());
   patterns.add(patterns.getContext(), subgroupSize, maxBitsPerShuffle);
+  patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
 }
 
 }; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index bdc31eb386dc..3f84454268dc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -23,9 +23,11 @@
 namespace mlir::iree_compiler {
 
 namespace {
 
-LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                   RewriterBase &rewriter,
-                                   linalg::LinalgOp contract) {
+static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          RewriterBase &rewriter,
+                                          linalg::LinalgOp contract,
+                                          bool promoteLhs = true,
+                                          bool promoteRhs = true) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
     return contract->emitError("missing mma schedule for contraction");
   }
@@ -65,8 +67,13 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   // TODO: We should read this from the lowering_config on the operation.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
   // should decide where to put these shared memory conversions.
-  layoutedLhs.setSharedMemoryConversion(true);
-  layoutedRhs.setSharedMemoryConversion(true);
+  if (promoteLhs) {
+    layoutedLhs.setSharedMemoryConversion(true);
+  }
+
+  if (promoteRhs) {
+    layoutedRhs.setSharedMemoryConversion(true);
+  }
 
   contract->setOperand(0, layoutedLhs.getResult());
   contract->setOperand(1, layoutedRhs.getResult());
@@ -82,9 +89,9 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   return success();
 }
 
-LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                   RewriterBase &rewriter,
-                                   linalg::LinalgOp conv) {
+static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                          RewriterBase &rewriter,
+                                          linalg::LinalgOp conv) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
     return conv->emitError("missing mma schedule for convolution");
   }
@@ -160,35 +167,164 @@ LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   return success();
 }
 
-LogicalResult setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
-                                       RewriterBase &rewriter,
-                                       linalg::LinalgOp contract) {
+/// Let's assume we have a matmul intrinsic (@) doing a matmul
+/// ((M, K) X (K, N)) which produces a particular layout:
+///
+/// C = A @ B
+///
+/// If we transpose and swap the operands, we can keep the same matmul
+/// intrinsic, but transpose the layout of the output intrinsic:
+///
+/// A.T = transpose(A)
+/// B.T = transpose(B)
+/// C.T = B.T @ A.T
+/// C = transpose(C.T)
+///
+/// This is useful when the "@" instruction that the hardware lowers to
+/// has a specific thread layout but the further uses of C expect a layout
+/// transposed relative to the produced layout.
+///
+/// For example, for "@" lowering to AMDGPU MFMA instructions, the operands
+/// have layout L and L.T and the result has the layout L.T.
+/// So if you have a chain of matmuls:
+///
+/// C (L.T) = A (L) @ B (L.T)
+/// E (L.T) = C (L.T) @ D (L.T)
+///           ^^^^^^^
+///           Expected layout by instruction is L
+///
+/// To fix this, we can apply this transformation on the first matrix:
+///
+/// C.T (L.T) = B.T (L) @ A (L.T)
+/// C (L) = transpose C.T (L.T)
+/// E (L.T) = C (L) @ D (L.T)
+///           ^^^^^
+///           Layout matches the instruction!
+///
+/// Note that the mathematical formula
+/// C = A @ B --> C.T = B.T @ A.T
+/// is only defined for the standard "@" function; it may be a different
+/// transformation for other indexing maps.
+///
+/// For linalg ops, since the indexing maps are part of the op definition,
+/// we can achieve the same transformation by simply swapping the operands.
+static void swapOperandsToTransposeIntrinsic(RewriterBase &rewriter,
+                                             linalg::GenericOp contractOp) {
+  Value lhs = contractOp->getOperand(0);
+  Value rhs = contractOp->getOperand(1);
+
+  SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+  std::swap(indexingMaps[0], indexingMaps[1]);
+
+  contractOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(indexingMaps));
+  contractOp->setOperand(0, rhs);
+  contractOp->setOperand(1, lhs);
+}
+
+static IREE::GPU::MMAScheduleAttr
+transposeSchedule(RewriterBase &rewriter, IREE::GPU::MMAScheduleAttr schedule) {
+  return rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+      schedule.getIntrinsic(), schedule.getSubgroupNCount(),
+      schedule.getSubgroupMCount());
+}
+
+static LogicalResult
+setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                         RewriterBase &rewriter, linalg::LinalgOp qkMatmul,
+                         linalg::LinalgOp pvMatmul) {
   // TODO: Add SIMT fallback.
   if (!schedule) {
-    return contract->emitError("missing mma schedule for contraction");
+    return pvMatmul->emitError("missing mma schedule for contraction");
   }
 
-  if (contract->hasAttr("attention_qk_matmul")) {
-    // subgroup_n count for attention matmul is always 1, because it is the
-    // reduction dimension. The subgroup_n count is in reality, for the second
-    // matmul.
-    IREE::GPU::MMAScheduleAttr qkSchedule =
-        rewriter.getAttr(
-            schedule.getIntrinsic(),
-            /*subgroup_m_count=*/schedule.getSubgroupMCount(),
-            /*subgroup_n_count=*/1);
-    return setContractionAnchor(qkSchedule, rewriter, contract);
+  // Check if the intrinsic output for qkMatmul can be reused for pvMatmul.
+  // We know that pvMatmul takes the result of qkMatmul as its lhs.
+  // If the intrinsic output can only be used as the rhs of pvMatmul, we swap
+  // the operands of both contracts to get the output in the transposed
+  // intrinsic layout.
+  bool reuseIntrinsicOutput = false;
+  bool transposeIntrinsic = false;
+
+  auto intrinsic = cast(schedule.getIntrinsic());
+  IREE::GPU::MMASingleSubgroupLayout lhsLayout =
+      intrinsic.getASingleSubgroupLayout();
+  IREE::GPU::MMASingleSubgroupLayout rhsLayout =
+      intrinsic.getBSingleSubgroupLayout();
+  IREE::GPU::MMASingleSubgroupLayout outLayout =
+      intrinsic.getCSingleSubgroupLayout();
+
+  auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
+                        IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
+    return (layoutA.element == layoutB.element) &&
+           (layoutA.thread == layoutB.thread) &&
+           (layoutA.tstrides == layoutB.tstrides);
+  };
+
+  // TODO: Move this check to KernelConfig and set appropriate attributes
+  // in lowering_config for the operation. This allows us to check shared
+  // memory usage and decide what kind of pipelining we can do.
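+
+  // In summary, the cases handled below are:
+  //   - outLayout matches lhsLayout: reuse the qkMatmul result directly as
+  //     the lhs of pvMatmul.
+  //   - outLayout matches rhsLayout: still reuse it, but swap the operands of
+  //     both matmuls (see swapOperandsToTransposeIntrinsic above) so the
+  //     intrinsic produces the transposed layout.
+  //   - otherwise: fall back to promoting the pvMatmul lhs to shared memory.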
+  if (matchLayout(outLayout, lhsLayout)) {
+    reuseIntrinsicOutput = true;
+  } else if (matchLayout(outLayout, rhsLayout)) {
+    reuseIntrinsicOutput = true;
+    transposeIntrinsic = true;
   }
 
-  if (contract->hasAttr("attention_pv_matmul")) {
-    // subgroup_n count for attention matmul is always 1, because it is the
-    // reduction dimension. The subgroup_n count is in reality, for the second
-    // matmul.
-    return setContractionAnchor(schedule, rewriter, contract);
+  // subgroup_n count for attention matmul is always 1, because it is the
+  // reduction dimension. The subgroup_n count in reality belongs to pvMatmul.
+  IREE::GPU::MMAScheduleAttr qkSchedule =
+      rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+          schedule.getIntrinsic(),
+          /*subgroup_m_count=*/schedule.getSubgroupMCount(),
+          /*subgroup_n_count=*/1);
+  IREE::GPU::MMAScheduleAttr pvSchedule = schedule;
+
+  // Transpose the intrinsic if requested. See docs for
+  // swapOperandsToTransposeIntrinsic for more information on why this is done.
+  if (transposeIntrinsic) {
+    auto qkGeneric = dyn_cast<linalg::GenericOp>(qkMatmul.getOperation());
+    auto pvGeneric = dyn_cast<linalg::GenericOp>(pvMatmul.getOperation());
+    if (!qkGeneric || !pvGeneric) {
+      pvMatmul->emitOpError("Non generic qkMatmul/pvMatmul transpose intrinsic "
+                            "not yet implemented");
+      return failure();
+    }
+    swapOperandsToTransposeIntrinsic(rewriter, qkGeneric);
+    swapOperandsToTransposeIntrinsic(rewriter, pvGeneric);
+    qkSchedule = transposeSchedule(rewriter, qkSchedule);
+    pvSchedule = transposeSchedule(rewriter, pvSchedule);
   }
 
-  return contract->emitError("attention matmul should have either "
-                             "attention_qk_matmul or attention_pv_matmul set");
+  if (failed(setContractionAnchor(qkSchedule, rewriter, qkMatmul))) {
+    return failure();
+  }
+
+  // Do not promote lhs of pvMatmul if we are reusing the intrinsic output.
+  bool promoteLhs = !reuseIntrinsicOutput;
+  bool promoteRhs = true;
+  if (transposeIntrinsic) {
+    std::swap(promoteLhs, promoteRhs);
+  }
+
+  return setContractionAnchor(pvSchedule, rewriter, pvMatmul, promoteLhs,
+                              promoteRhs);
+}
+
+static Operation *getOpWithAttr(Operation *root, StringRef attr) {
+  Operation *result = nullptr;
+  WalkResult walkResult = root->walk([&](Operation *op) {
+    if (op->hasAttr(attr)) {
+      if (result) {
+        return WalkResult::interrupt();
+      }
+      result = op;
+    }
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted()) {
+    return nullptr;
+  }
+  return result;
 }
 
 struct LLVMGPUConfigureTensorLayoutsPass final
@@ -212,19 +348,33 @@ struct LLVMGPUConfigureTensorLayoutsPass final
     // now, layout setting for other problems like reductions is TODO.
SmallVector contracts; SmallVector convs; - SmallVector attentionMatmuls; + + auto attentionQKMatmul = dyn_cast_or_null( + getOpWithAttr(func, "attention_qk_matmul")); + auto attentionPVMatmul = dyn_cast_or_null( + getOpWithAttr(func, "attention_pv_matmul")); + + if (attentionQKMatmul && !attentionPVMatmul) { + func->emitError("Expected attention attributes to be set properly"); + return signalPassFailure(); + } + + if (!attentionQKMatmul && attentionPVMatmul) { + func->emitError("Expected attention attributes to be set properly"); + return signalPassFailure(); + } func->walk([&](linalg::LinalgOp linalgOp) { + if (linalgOp == attentionQKMatmul || linalgOp == attentionPVMatmul) { + return WalkResult::advance(); + } + if (linalg::isaContractionOpInterface(linalgOp)) { - if (linalgOp->hasAttr("attention_qk_matmul") || - linalgOp->hasAttr("attention_pv_matmul")) { - attentionMatmuls.push_back(linalgOp); - } else { - contracts.push_back(linalgOp); - } + contracts.push_back(linalgOp); } else if (succeeded(linalg::inferConvolutionDims(linalgOp))) { convs.push_back(linalgOp); } + return WalkResult::advance(); }); IRRewriter rewriter(func); @@ -241,9 +391,9 @@ struct LLVMGPUConfigureTensorLayoutsPass final } } - for (linalg::LinalgOp attentionMatmul : attentionMatmuls) { - if (failed(setAttentionMatmulAnchor(scheduleAttr, rewriter, - attentionMatmul))) { + if (attentionQKMatmul && attentionPVMatmul) { + if (failed(setAttentionMatmulAnchor( + scheduleAttr, rewriter, attentionQKMatmul, attentionPVMatmul))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index faec5386ecc1..cbe19e53ea93 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -868,7 +868,6 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, // Preprocessing for vector distribution. 
   funcPassManager.addPass(createLLVMGPUCastTypeToFitMMAPass());
-  funcPassManager.addPass(createAMDGPUPrepareForChainedMatmulPass());
 
   // Vector SIMD -> Vector SIMT
   funcPassManager.addPass(createLLVMGPUConfigureVectorLayoutsPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 86e7f0b15242..eb8f4f177396 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -3,6 +3,11 @@
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s
 
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s --check-prefix=MEMORY
+
 #config = #iree_codegen.lowering_config
 #translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}>
@@ -591,10 +596,16 @@ hal.executable private @attention_20x4096x64x4096x64 {
 // CHECK: transfer_read
 // CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf32>)
+// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x4x1x1x1x4xf32>)
 // CHECK-COUNT-48: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
+// Check that we only use alloc for Q, K, and V. No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_20x4096x64x4096x64()
+// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc
+
 // -----
 
 #config = #iree_codegen.lowering_config
@@ -640,6 +651,67 @@ hal.executable private @attention_multiple_m_transpose {
 // CHECK-LABEL: func.func @attention_multiple_m_transpose()
 // CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
-// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
+// CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x8x1x1x1x4xf32>)
 // CHECK-COUNT-96: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
+
+// Check that we only use alloc for Q, K, and V. No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_multiple_m_transpose() +// MEMORY-COUNT-3: memref.alloc +// MEMORY-NOT: memref.alloc + +// ----- + +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, subgroup_m_count = 4, subgroup_n_count = 1>}> + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +hal.executable private @attention_mfma_32x32x8 { + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export public @attention_mfma_32x32x8 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @attention_mfma_32x32x8() attributes {translation_info = #translation} { + %cst = arith.constant 1.0 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [24, 64, 4608, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x64x4608x128xf16> + %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> + %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> + %7 = tensor.empty() : tensor<64x4608x24x128xf16> + %8 = tensor.empty() : tensor<24x64x4608x128xf16> + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<64x4608x24x128xf16> + flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0, 0], sizes = [64, 4608, 24, 128], strides = [1, 1, 1, 1] : tensor<64x4608x24x128xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-LABEL: func.func @attention_mfma_32x32x8() +// CHECK: scf.for %{{.*}} = %c0 to %c144 step %c1 +// CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x4x1x4x1x4xf32>) +// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32> +// CHECK: scf.yield + +// Check that we only use alloc for Q, K, and V. 
No shared memory for S is
+// needed because the intrinsic layout matches.
+// MEMORY-LABEL: func.func @attention_mfma_32x32x8()
+// MEMORY-COUNT-3: memref.alloc
+// MEMORY-NOT: memref.alloc
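
For reference, the operand-swapping trick used in LLVMGPUConfigureTensorLayouts.cpp above rests on the identity (A @ B)^T = B^T @ A^T. The small standalone C++ sketch below (illustrative only, not part of the patch; it uses a toy row-major Matrix alias instead of MLIR vectors and indexing maps) checks that swapping and transposing the operands reproduces the transposed product, which is why both contractions can swap their operands and still compute the same attention result, just in the transposed intrinsic layout.

#include <array>
#include <cassert>
#include <cstddef>

// Minimal row-major matrix type, sized at compile time.
template <std::size_t R, std::size_t C>
using Matrix = std::array<std::array<float, C>, R>;

template <std::size_t M, std::size_t K, std::size_t N>
Matrix<M, N> matmul(const Matrix<M, K> &a, const Matrix<K, N> &b) {
  Matrix<M, N> c{};
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j)
      for (std::size_t k = 0; k < K; ++k)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

template <std::size_t R, std::size_t C>
Matrix<C, R> transpose(const Matrix<R, C> &m) {
  Matrix<C, R> t{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      t[j][i] = m[i][j];
  return t;
}

int main() {
  Matrix<2, 3> a = {{{1, 2, 3}, {4, 5, 6}}};
  Matrix<3, 2> b = {{{7, 8}, {9, 10}, {11, 12}}};

  // C = A @ B, computed directly.
  Matrix<2, 2> c = matmul(a, b);

  // C.T computed by swapping and transposing the operands: C.T = B.T @ A.T.
  // swapOperandsToTransposeIntrinsic expresses the same idea by swapping the
  // linalg operands and their indexing maps instead of materializing
  // transposes.
  Matrix<2, 2> ct = matmul(transpose(b), transpose(a));

  // The transposed product must equal the transpose of the product.
  assert(transpose(c) == ct);
  return 0;
}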