From f874f02c570b765df13552133530070ae7032c67 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 30 Jul 2024 10:13:38 +0000 Subject: [PATCH] [mlir][vector] Use `DenseI64ArrayAttr` for shuffle masks Follow on from #100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC). --- .../mlir/Dialect/Vector/IR/VectorOps.td | 11 +++-- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 8 ++-- .../VectorToSPIRV/VectorToSPIRV.cpp | 6 +-- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 42 ++++++++----------- ...sertExtractStridedSliceRewritePatterns.cpp | 3 +- .../Vector/Transforms/VectorLinearize.cpp | 22 ++++------ 6 files changed, 36 insertions(+), 56 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3cdbd218745675..434ff3956c2501 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -421,7 +421,7 @@ def Vector_ShuffleOp : TCresVTEtIsSameAsOpBase<0, 1>>, InferTypeOpAdaptor]>, Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2, - I64ArrayAttr:$mask)>, + DenseI64ArrayAttr:$mask)>, Results<(outs AnyVector:$vector)> { let summary = "shuffle operation"; let description = [{ @@ -459,11 +459,7 @@ def Vector_ShuffleOp : : vector, vector ; yields vector<2xf32> ``` }]; - let builders = [ - OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef")> - ]; - let hasFolder = 1; - let hasCanonicalizer = 1; + let extraClassDeclaration = [{ VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); @@ -475,7 +471,10 @@ def Vector_ShuffleOp : return ::llvm::cast(getVector().getType()); } }]; + let assemblyFormat = "operands $mask attr-dict `:` type(operands)"; + + let hasFolder = 1; let hasVerifier = 1; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f6b1c42dcd24c4..c787f23a75f3c1 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -994,7 +994,7 @@ class VectorShuffleOpConversion auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getResultVectorType(); Type llvmType = typeConverter->convertType(vectorType); - auto maskArrayAttr = shuffleOp.getMask(); + auto mask = shuffleOp.getMask(); // Bail if result type cannot be lowered. if (!llvmType) @@ -1014,8 +1014,7 @@ class VectorShuffleOpConversion // type, there is direct shuffle support in LLVM. Use it! if (rank <= 1 && v1Type == v2Type) { Value llvmShuffleOp = rewriter.create( - loc, adaptor.getV1(), adaptor.getV2(), - LLVM::convertArrayToIndices(maskArrayAttr)); + loc, adaptor.getV1(), adaptor.getV2(), SmallVector(mask)); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); } @@ -1029,8 +1028,7 @@ class VectorShuffleOpConversion eltType = cast(llvmType).getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; - for (const auto &en : llvm::enumerate(maskArrayAttr)) { - int64_t extPos = cast(en.value()).getInt(); + for (int64_t extPos : mask) { Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 890706bf1bb2e3..db08457be8e5a5 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -527,10 +527,8 @@ struct VectorShuffleOpConvert final return rewriter.notifyMatchFailure(shuffleOp, "unsupported result vector type"); - SmallVector mask = llvm::map_to_vector<4>( - shuffleOp.getMask(), [](Attribute attr) -> int32_t { - return cast(attr).getValue().getZExtValue(); - }); + // Cast mask from int64_t to int32_t. + SmallVector mask(shuffleOp.getMask()); VectorType oldV1Type = shuffleOp.getV1VectorType(); VectorType oldV2Type = shuffleOp.getV2VectorType(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 669ae586e57861..5047bd925d4c5d 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, // ShuffleOp //===----------------------------------------------------------------------===// -void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, - Value v2, ArrayRef mask) { - build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask)); -} - LogicalResult ShuffleOp::verify() { VectorType resultType = getResultVectorType(); VectorType v1Type = getV1VectorType(); @@ -2491,8 +2486,8 @@ LogicalResult ShuffleOp::verify() { return emitOpError("dimension mismatch"); } // Verify mask length. - auto maskAttr = getMask().getValue(); - int64_t maskLength = maskAttr.size(); + ArrayRef mask = getMask(); + int64_t maskLength = mask.size(); if (maskLength <= 0) return emitOpError("invalid mask length"); if (maskLength != resultType.getDimSize(0)) @@ -2500,10 +2495,9 @@ LogicalResult ShuffleOp::verify() { // Verify all indices. int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); - for (const auto &en : llvm::enumerate(maskAttr)) { - auto attr = llvm::dyn_cast(en.value()); - if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) - return emitOpError("mask index #") << (en.index() + 1) << " out of range"; + for (auto [idx, maskPos] : llvm::enumerate(mask)) { + if (maskPos < 0 || maskPos >= indexSize) + return emitOpError("mask index #") << (idx + 1) << " out of range"; } return success(); } @@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional, return success(); } -static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) { - uint64_t expected = begin; - return idxArr.size() == width && - llvm::all_of(idxArr.getAsValueRange(), - [&expected](auto attr) { - return attr.getZExtValue() == expected++; - }); +template +static bool isStepIndexArray(ArrayRef idxArr, uint64_t begin, size_t width) { + T expected = begin; + return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) { + return value == expected++; + }); } OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { @@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { SmallVector results; auto lhsElements = llvm::cast(lhs).getValues(); auto rhsElements = llvm::cast(rhs).getValues(); - for (const auto &index : this->getMask().getAsValueRange()) { - int64_t i = index.getZExtValue(); + for (int64_t i : this->getMask()) { if (i >= lhsSize) { results.push_back(rhsElements[i - lhsSize]); } else { @@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern { LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { VectorType v1VectorType = shuffleOp.getV1VectorType(); - ArrayAttr mask = shuffleOp.getMask(); + ArrayRef mask = shuffleOp.getMask(); if (v1VectorType.getRank() > 0) return failure(); if (mask.size() != 1) return failure(); VectorType resType = VectorType::Builder(v1VectorType).setShape({1}); - if (llvm::cast(mask[0]).getInt() == 0) + if (mask[0] == 0) rewriter.replaceOpWithNewOp(shuffleOp, resType, shuffleOp.getV1()); else @@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern { op, "ShuffleOp types don't match an interleave"); } - ArrayAttr shuffleMask = op.getMask(); + ArrayRef shuffleMask = op.getMask(); int64_t resultVectorSize = resultType.getNumElements(); for (int i = 0, e = resultVectorSize / 2; i < e; ++i) { - int64_t maskValueA = cast(shuffleMask[i * 2]).getInt(); - int64_t maskValueB = cast(shuffleMask[(i * 2) + 1]).getInt(); + int64_t maskValueA = shuffleMask[i * 2]; + int64_t maskValueB = shuffleMask[(i * 2) + 1]; if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i) return rewriter.notifyMatchFailure(op, "ShuffleOp mask not interleaving"); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 37216cea7b6150..ec2ef3fc7501c2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle off += stride) offsets.push_back(off); rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), - op.getVector(), - rewriter.getI64ArrayAttr(offsets)); + op.getVector(), offsets); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 4a3ae1b850517e..868397f2daaae4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final } // Perform a shuffle to extract the kD vector. rewriter.replaceOpWithNewOp( - extractOp, dstType, srcVector, srcVector, - rewriter.getI64ArrayAttr(indices)); + extractOp, dstType, srcVector, srcVector, indices); return success(); } @@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final // that needs to be shuffled to the destination vector. If shuffleSliceLen > // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of // elements) instead of scalars. - ArrayAttr mask = shuffleOp.getMask(); + ArrayRef mask = shuffleOp.getMask(); int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; llvm::SmallVector indices(totalSizeOfShuffledElmnts); - for (auto [i, value] : - llvm::enumerate(mask.getAsValueRange())) { - - int64_t v = value.getZExtValue(); + for (auto [i, value] : llvm::enumerate(mask)) { std::iota(indices.begin() + shuffleSliceLen * i, indices.begin() + shuffleSliceLen * (i + 1), - shuffleSliceLen * v); + shuffleSliceLen * value); } - rewriter.replaceOpWithNewOp( - shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices)); + rewriter.replaceOpWithNewOp(shuffleOp, dstType, vec1, + vec2, indices); return success(); } @@ -368,8 +364,7 @@ struct LinearizeVectorExtract final llvm::SmallVector indices(size); std::iota(indices.begin(), indices.end(), linearizedOffset); rewriter.replaceOpWithNewOp( - extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), - rewriter.getI64ArrayAttr(indices)); + extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); return success(); } @@ -452,8 +447,7 @@ struct LinearizeVectorInsert final // [offset+srcNumElements, end) rewriter.replaceOpWithNewOp( - insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), - rewriter.getI64ArrayAttr(indices)); + insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices); return success(); }