Skip to content

Commit

Permalink
[mlir][vector] Use DenseI64ArrayAttr for shuffle masks
Browse files Browse the repository at this point in the history
Follow on from llvm#100997. This again removes from boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is a NFC).
  • Loading branch information
MacDue committed Jul 30, 2024
1 parent 95e9aff commit f874f02
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 56 deletions.
11 changes: 5 additions & 6 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def Vector_ShuffleOp :
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
I64ArrayAttr:$mask)>,
DenseI64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
Expand Down Expand Up @@ -459,11 +459,7 @@ def Vector_ShuffleOp :
: vector<f32>, vector<f32> ; yields vector<2xf32>
```
}];
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
];
let hasFolder = 1;
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
Expand All @@ -475,7 +471,10 @@ def Vector_ShuffleOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
}];

let assemblyFormat = "operands $mask attr-dict `:` type(operands)";

let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
Expand Down
8 changes: 3 additions & 5 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ class VectorShuffleOpConversion
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getResultVectorType();
Type llvmType = typeConverter->convertType(vectorType);
auto maskArrayAttr = shuffleOp.getMask();
auto mask = shuffleOp.getMask();

// Bail if result type cannot be lowered.
if (!llvmType)
Expand All @@ -1014,8 +1014,7 @@ class VectorShuffleOpConversion
// type, there is direct shuffle support in LLVM. Use it!
if (rank <= 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.getV1(), adaptor.getV2(),
LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
loc, adaptor.getV1(), adaptor.getV2(), SmallVector<int32_t>(mask));
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
Expand All @@ -1029,8 +1028,7 @@ class VectorShuffleOpConversion
eltType = cast<VectorType>(llvmType).getElementType();
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
for (int64_t extPos : mask) {
Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,8 @@ struct VectorShuffleOpConvert final
return rewriter.notifyMatchFailure(shuffleOp,
"unsupported result vector type");

SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
shuffleOp.getMask(), [](Attribute attr) -> int32_t {
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});
// Cast mask from int64_t to int32_t.
SmallVector<int32_t> mask(shuffleOp.getMask());

VectorType oldV1Type = shuffleOp.getV1VectorType();
VectorType oldV2Type = shuffleOp.getV2VectorType();
Expand Down
42 changes: 17 additions & 25 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
Value v2, ArrayRef<int64_t> mask) {
build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
}

LogicalResult ShuffleOp::verify() {
VectorType resultType = getResultVectorType();
VectorType v1Type = getV1VectorType();
Expand All @@ -2491,19 +2486,18 @@ LogicalResult ShuffleOp::verify() {
return emitOpError("dimension mismatch");
}
// Verify mask length.
auto maskAttr = getMask().getValue();
int64_t maskLength = maskAttr.size();
ArrayRef<int64_t> mask = getMask();
int64_t maskLength = mask.size();
if (maskLength <= 0)
return emitOpError("invalid mask length");
if (maskLength != resultType.getDimSize(0))
return emitOpError("mask length mismatch");
// Verify all indices.
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (const auto &en : llvm::enumerate(maskAttr)) {
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
return emitOpError("mask index #") << (en.index() + 1) << " out of range";
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
if (maskPos < 0 || maskPos >= indexSize)
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
}
Expand All @@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
return success();
}

static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
uint64_t expected = begin;
return idxArr.size() == width &&
llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
[&expected](auto attr) {
return attr.getZExtValue() == expected++;
});
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
T expected = begin;
return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
return value == expected++;
});
}

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
SmallVector<Attribute> results;
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
int64_t i = index.getZExtValue();
for (int64_t i : this->getMask()) {
if (i >= lhsSize) {
results.push_back(rhsElements[i - lhsSize]);
} else {
Expand All @@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
VectorType v1VectorType = shuffleOp.getV1VectorType();
ArrayAttr mask = shuffleOp.getMask();
ArrayRef<int64_t> mask = shuffleOp.getMask();
if (v1VectorType.getRank() > 0)
return failure();
if (mask.size() != 1)
return failure();
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
if (mask[0] == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
else
Expand Down Expand Up @@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
op, "ShuffleOp types don't match an interleave");
}

ArrayAttr shuffleMask = op.getMask();
ArrayRef<int64_t> shuffleMask = op.getMask();
int64_t resultVectorSize = resultType.getNumElements();
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
int64_t maskValueA = shuffleMask[i * 2];
int64_t maskValueB = shuffleMask[(i * 2) + 1];
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
return rewriter.notifyMatchFailure(op,
"ShuffleOp mask not interleaving");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle
off += stride)
offsets.push_back(off);
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
op.getVector(),
rewriter.getI64ArrayAttr(offsets));
op.getVector(), offsets);
return success();
}
};
Expand Down
22 changes: 8 additions & 14 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
}
// Perform a shuffle to extract the kD vector.
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstType, srcVector, srcVector,
rewriter.getI64ArrayAttr(indices));
extractOp, dstType, srcVector, srcVector, indices);
return success();
}

Expand Down Expand Up @@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
ArrayAttr mask = shuffleOp.getMask();
ArrayRef<int64_t> mask = shuffleOp.getMask();
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
for (auto [i, value] :
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {

int64_t v = value.getZExtValue();
for (auto [i, value] : llvm::enumerate(mask)) {
std::iota(indices.begin() + shuffleSliceLen * i,
indices.begin() + shuffleSliceLen * (i + 1),
shuffleSliceLen * v);
shuffleSliceLen * value);
}

rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
vec2, indices);
return success();
}

Expand Down Expand Up @@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
rewriter.getI64ArrayAttr(indices));
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);

return success();
}
Expand Down Expand Up @@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)

rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
rewriter.getI64ArrayAttr(indices));
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

return success();
}
Expand Down

0 comments on commit f874f02

Please sign in to comment.