[LLVMGPU] Pad to intrinsic shape in LLVMGPUPadAndVectorDistribute pipeline (#18632)

This patch makes the LLVMGPUPromoteToFitMMA pass pad to a multiple of the
intrinsic shape, instead of padding to a multiple of 1.

Fixes #18602
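As a rough sketch of what "pad to a multiple of the intrinsic shape" means numerically (plain C++, not the pass itself; roundUpToMultiple is a made-up helper):

  #include <cassert>
  #include <cstdint>

  // Round `size` up to the next multiple of `multiple` (multiple > 0).
  int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
    return ((size + multiple - 1) / multiple) * multiple;
  }

  int main() {
    // Old behavior: every padding dim used a multiple of 1, i.e. a no-op pad.
    assert(roundUpToMultiple(24, 1) == 24);
    // New behavior: pad to the workgroup/reduction tile size, e.g. the N = 24
    // dimension of the pad_batch_matmul test below with a workgroup tile of 32.
    assert(roundUpToMultiple(24, 32) == 32);
    return 0;
  }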
Groverkss authored Oct 8, 2024
1 parent 6001f9c commit ad68964
Showing 3 changed files with 117 additions and 22 deletions.
@@ -36,23 +36,17 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
}

void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
utils::IteratorType targetIterType, bool nofold) const {
ArrayRef<int64_t> paddingDims,
ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
assert(paddingDims.size() == padToMultipleOf.size() &&
"invalid pad multiples for padding dimensions");

LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);

SmallVector<int64_t> paddingDims;
for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
if (iterType == targetIterType) {
paddingDims.push_back(index);
}
}

SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);
SmallVector<bool> packPaddings(op.getNumDpsInputs(), noFold);

// One is enough because they will essentially be padded to corresponding
// tile sizes, which should be multiple of MMA shapes.
SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
SmallVector<Attribute> paddingValueAttributes;
for (auto &operand : op->getOpOperands()) {
auto elemType = getElementTypeOrSelf(operand.get().getType());
@@ -80,18 +74,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final

// Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
// we can kick canonicalization patterns to fold outer tensor.pad ops away.
bool nofold = false;
bool noFold = false;
utils::IteratorType targetIterType = utils::IteratorType::parallel;
switch (targetDimensions) {
case LLVMGPUMatmulPadOption::ParallelDims:
LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
targetIterType = utils::IteratorType::parallel;
nofold = false;
noFold = false;
break;
case LLVMGPUMatmulPadOption::ReductionDims:
LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
targetIterType = utils::IteratorType::reduction;
nofold = true;
noFold = true;
break;
default: // Unreachable.
assert(false);
@@ -106,8 +100,47 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
});

IRRewriter rewriter(ctx);
for (auto op : candidates) {
padWithZeroValue(rewriter, op, targetIterType, nofold);
for (linalg::LinalgOp op : candidates) {
SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
getLoweringConfig(op));
if (config) {
switch (targetDimensions) {
case LLVMGPUMatmulPadOption::ParallelDims:
padMultiples = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
break;
case LLVMGPUMatmulPadOption::ReductionDims:
padMultiples = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
break;
default:
assert(false && "Unexpected target dimensions");
break;
}
}

// Populate padding dimensions.
SmallVector<int64_t> paddingDimensions;
for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
if (iter == targetIterType) {
paddingDimensions.push_back(idx);
}
}

// Populate tile sizes. We pad to multiples of the workgroup/reduction
// tile sizes based on the selected target tiling dimensions.
// This pass runs after the target tiling level has been selected, so it
// pads all dimensions to the selected tile sizes.
SmallVector<int64_t> padToMultipleOf;
for (int64_t dim : paddingDimensions) {
if (padMultiples[dim] != 0) {
padToMultipleOf.push_back(padMultiples[dim]);
}
}

padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
noFold);
}

{
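To restate the new driver logic above outside of MLIR (a minimal standalone C++ sketch; the enum and vector types stand in for the IREE/MLIR ones, and selectPadMultiples is an invented name):

  #include <cstddef>
  #include <cstdint>
  #include <utility>
  #include <vector>

  enum class IteratorType { parallel, reduction };

  // Mirror of the two loops in the pass: collect the loop indices whose
  // iterator type matches the target, then keep the nonzero tile sizes from
  // the chosen tiling level (workgroup or reduction) as pad multiples.
  std::pair<std::vector<int64_t>, std::vector<int64_t>>
  selectPadMultiples(const std::vector<IteratorType> &iterTypes,
                     const std::vector<int64_t> &padMultiples,
                     IteratorType target) {
    std::vector<int64_t> paddingDims;
    for (size_t idx = 0; idx < iterTypes.size(); ++idx)
      if (iterTypes[idx] == target)
        paddingDims.push_back(static_cast<int64_t>(idx));

    std::vector<int64_t> padToMultipleOf;
    for (int64_t dim : paddingDims)
      if (padMultiples[dim] != 0)
        padToMultipleOf.push_back(padMultiples[dim]);

    return {paddingDims, padToMultipleOf};
  }

With the pad_batch_matmul config below (workgroup = [1, 16, 32, 0]) and a batch matmul's parallel/parallel/parallel/reduction loops, targeting parallel dims yields paddingDims = {0, 1, 2} and padToMultipleOf = {1, 16, 32}.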
@@ -484,6 +484,67 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8]}>
#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2>}>

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

hal.executable public @pad_batch_matmul {
hal.executable.variant public @rocm_hsaco_fb target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @pad_batch_matmul ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @pad_batch_matmul() attributes {translation_info = #translation} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>> -> tensor<196x16x24xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [196, 24, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>> -> tensor<196x24x24xf32>
%5 = tensor.empty() : tensor<196x16x24xf32>
%6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
%7 = linalg.batch_matmul {lowering_config = #config} ins(%3, %4 : tensor<196x16x24xf32>, tensor<196x24x24xf32>) outs(%6 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : tensor<196x16x24xf32> -> !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
return
}
}
}
}

// This test checks that we can handle an unaligned batch matmul whose sizes
// are smaller than the chosen tile sizes. We mainly want to make sure the
// example compiles. We also check that the transfer_read/transfer_write ops
// are produced with the correct in_bounds attrs for the padded dimensions.

// CHECK-LABEL: @pad_batch_matmul
// CHECK: scf.for
// LHS
// CHECK: vector.transfer_read
// CHECK-SAME: in_bounds = [true, true, true]
// CHECK-SAME: memref<196x16x24xf32
// CHECK-SAME: vector<1x1x1xf32>
// RHS
// CHECK: vector.transfer_read
// CHECK-SAME: in_bounds = [true, true, false]
// CHECK-SAME: memref<1x8x24xf32
// CHECK-SAME: vector<1x1x2xf32>
// CHECK: scf.yield
// OUTPUT
// CHECK: vector.transfer_write
// CHECK-SAME: in_bounds = [true, true, false]
// CHECK-SAME: vector<1x4x1xf32>
// CHECK-SAME: memref<1x16x24xf32

// -----

// This test ensures that the contraction schedules we generate not only work on the contraction
// itself, but are also compatible with the transfer_read layout anchors.
// Currently the transfer_read layout anchors expect WorkgroupSize % (WgTileSize / numelPerThread) == 0.
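The divisibility constraint in the comment above can be checked in isolation (a toy C++ check; all three values here are illustrative assumptions, not taken from the test):

  #include <cassert>

  int main() {
    // Constraint from the comment: WorkgroupSize % (WgTileSize / numelPerThread) == 0.
    int workgroupSize = 128;  // illustrative value
    int wgTileSize = 32;      // illustrative value
    int numelPerThread = 8;   // illustrative value
    assert(workgroupSize % (wgTileSize / numelPerThread) == 0);
    return 0;
  }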
@@ -12,6 +12,7 @@
#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
#map4 = affine_map<()[s0] -> (-s0 + 64)>
#map5 = affine_map<()[s0] -> (-s0 + 128)>
#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16]}>
func.func @batch_matmul_f16() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
@@ -29,7 +30,7 @@ func.func @batch_matmul_f16() {
%8 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
%9 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
%10 = linalg.fill ins(%cst : f16) outs(%7 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
%11 = linalg.batch_matmul {lowering_config = #config} ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
return
}
@@ -48,14 +49,14 @@ func.func @batch_matmul_f16() {
// PARALLEL-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
// PARALLEL-SAME: outs(%[[FILL]]

// The reduction dim is not tiled in the test case, so it pads it to the same
// shape.
// The reduction dim is not tiled in the test case, so it pads it to the
// matmul intrinsic k.
// REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1281xf16>
// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16>
// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1281x?xf16>
// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16>
// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul
// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
// REDUCTION-SAME: outs(%[[FILL]]
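The updated REDUCTION checks above follow from the same round-up arithmetic (a throwaway C++ check with a locally defined helper, not code from the pass):

  #include <cassert>
  #include <cstdint>

  int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
    return ((size + multiple - 1) / multiple) * multiple;
  }

  int main() {
    // batch_matmul_f16: K = 1281, reduction tile (pad multiple) = 16, so both
    // operands are padded from ...1281... to ...1296..., as the CHECKs expect.
    assert(roundUpToMultiple(1281, 16) == 1296);
    return 0;
  }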

