Skip to content

Commit

Permalink
[mlir][Interfaces] Clean up DestinationStyleOpInterface
Browse files Browse the repository at this point in the history
* "init" operands are specified with `MutableOperandRange` (which gives access to the underlying `OpOperand *`). No more magic numbers.
* Remove most interface methods and make them helper functions. Only `getInitsMutable` should be implemented.
* Provide separate helper functions for accessing mutable/immutable operands (`OpOperand`/`Value`, in line with #66515): `getInitsMutable` and `getInits` (same naming convention as auto-generated op accessors). `getInputOperands` was not renamed because this function cannot return a `MutableOperandRange` (because the operands are not necessarily consecutive). `OpOperandVector` is no longer needed.
* The new `getDpsInits`/`getDpsInitsMutable` is more efficient than the old `getDpsInitOperands` because no `SmallVector` is created. The new functions return a range of operands.
* Fix a bug in `getDpsInputOperands`: out-of-bounds operands were potentially returned.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
  • Loading branch information
matthias-springer committed Sep 21, 2023
1 parent 7fcbb64 commit 89f47c7
Show file tree
Hide file tree
Showing 34 changed files with 356 additions and 484 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ def Bufferization_MaterializeInDestinationOp
return ::llvm::cast<RankedTensorType>(getResult().getType());
}

std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];

let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,12 @@ def LinalgStructuredInterface
are expection. For example, in `map` output operand isn't used in
the block.
}],
/*retTy=*/"OpOperandVector",
/*retTy=*/"::llvm::SmallVector<OpOperand *>",
/*methodName=*/"getOpOperandsMatchingBBargs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
OpOperandVector result;
::llvm::SmallVector<OpOperand *> result;
result.reserve($_op->getNumOperands());
llvm::transform(
this->getOperation()->getOpOperands(),
Expand Down
9 changes: 1 addition & 8 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
// Method to implement DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
std::pair<unsigned, unsigned> outputsIndexAndLength =
getODSOperandIndexAndLength(1);
return std::make_pair<int64_t, int64_t>(
outputsIndexAndLength.first,
outputsIndexAndLength.first + outputsIndexAndLength.second);
}
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
let hasVerifier = 1;
}
Expand Down
28 changes: 8 additions & 20 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
getRegionBuilder() {
return nullptr;
}
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - getOutputs().size(), getNumOperands};
}

MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -283,11 +281,9 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}

// Implement functions necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
OpOperandVector getOpOperandsMatchingBBargs() {
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
return getDpsInputOperands();
}

Expand Down Expand Up @@ -381,9 +377,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
getRegionBuilder() {
return nullptr;
}
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {getInits().size(), getNumOperands()};
}
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -446,10 +440,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
}

// Implement functions necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
Expand Down Expand Up @@ -517,10 +508,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
}

// Implement functions necessary for DestinationStyleOpInterface.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
Expand Down
13 changes: 3 additions & 10 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
}];

let extraClassDeclaration = [{
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];

let hasFolder = 1;
Expand Down Expand Up @@ -892,9 +890,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }

std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -1714,10 +1710,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };

/// Return position for init operand. Init operand is `dest`.
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

/// Interface method for ConditionallySpeculatable.
Speculation::Speculatability getSpeculatability();
Expand Down
8 changes: 3 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1330,8 +1330,8 @@ def Vector_TransferReadOp :
// MaskableOpInterface methods.
bool supportsPassthru() { return true; }

std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {0, 0}; // empty range (no init operands)
MutableOperandRange getDpsInitsMutable() {
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
}];

Expand Down Expand Up @@ -1494,9 +1494,7 @@ def Vector_TransferWriteOp :
/// ops of other dialects.
Value getValue() { return getVector(); }

std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `source` operand
}
MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
}];

let hasFolder = 1;
Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
#include "llvm/ADT/SmallVector.h"

namespace mlir {
/// OpOperand vector that implicitly converts to a Value vector.
struct OpOperandVector : public llvm::SmallVector<OpOperand *> {
operator SmallVector<Value>();
};

namespace detail {
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
LogicalResult verifyDestinationStyleOpInterface(Operation *op);
Expand Down
Loading

0 comments on commit 89f47c7

Please sign in to comment.