From e8ff88da4e1846ac248eaff5bc55ebdc728b483e Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 19 Jul 2024 09:59:00 +0000 Subject: [PATCH] Fixups --- .../Dialect/Vector/Transforms/VectorTransforms.h | 2 +- .../Vector/Transforms/VectorMaskElimination.cpp | 12 ++++++++++++ mlir/test/Dialect/Vector/eliminate-masks.mlir | 10 +++++----- .../test/lib/Dialect/Vector/TestVectorTransforms.cpp | 9 ++------- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h index 847f333d6a9310..e815e026305fab 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h @@ -116,7 +116,7 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter); -/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take. +// Structure to hold the range of `vector.vscale`. struct VscaleRange { unsigned vscaleMin; unsigned vscaleMax; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp index abec8c75b8fc91..486784a9cf102b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp @@ -1,3 +1,11 @@ +//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" @@ -105,10 +113,14 @@ void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, return; OpBuilder::InsertionGuard g(rewriter); + + // Build worklist so we can safely insert new ops in + // `resolveAllTrueCreateMaskOp()`. SmallVector worklist; function.walk([&](vector::CreateMaskOp createMaskOp) { worklist.push_back(createMaskOp); }); + rewriter.setInsertionPointToStart(&function.front()); for (auto mask : worklist) (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange); diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir index 99c9a60a09facb..da02fed1efa7de 100644 --- a/mlir/test/Dialect/Vector/eliminate-masks.mlir +++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir @@ -16,7 +16,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32> %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> { // 1. Extract a slice. - %extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor + %extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor // 2. Create a mask for the slice. %dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor @@ -57,8 +57,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) { %mask = vector.create_mask %dim : vector<[4]xi1> "test.some_use"(%mask) : (vector<[4]xi1>) -> () // !!! Here the size of the mask could shrink in the next iteration. - %next_num_els = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale] - %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor + %next_num_elts = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale] + %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_elts] [1] : tensor<1000xf32> to tensor scf.yield %new_extracted_slice : tensor } "test.some_use"(%slice) : (tensor) -> () @@ -110,8 +110,8 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32> %c4 = arith.constant 4 : index %vscale = vector.vscale %c4_vscale = arith.muli %vscale, %c4 : index - // This is _very_ simple but since addi is not a constant value bounds will - // be used to resolve it. + // This is _very_ simple but since tensor.dim is not a constant value bounds + // will be used to resolve it. %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32> %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1> "test.some_use"(%mask) : (vector<3x[4]xi1>) -> () diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f74ff2725f815e..69d8ec79407b5c 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -885,15 +885,10 @@ struct TestEliminateVectorMasks : PassWrapper(pass) {} Option vscaleMin{ - *this, "vscale-min", - llvm::cl::desc( - "Minimum value `vector.vscale` can possibly be at runtime."), + *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."), llvm::cl::init(1)}; - Option vscaleMax{ - *this, "vscale-max", - llvm::cl::desc( - "Maximum value `vector.vscale` can possibly be at runtime."), + *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."), llvm::cl::init(16)}; StringRef getArgument() const final { return "test-eliminate-vector-masks"; }