Skip to content

Commit

Permalink
[MLIR][linalg] Fix unpack rewriter for dynamic shapes (#67096)
Browse files Browse the repository at this point in the history
Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would
assert that we cannot create a `tensor.empty` operation with dynamic
shapes.

The problem stems from the fact that we were not using the right builder
for the `tensor.empty` operation. Indeed, each dynamic dim needs to be
specified by an input variable.

Simply provide the dynamic dimensions to the `tensor.empty` builder to
fix that.
  • Loading branch information
qcolombet committed Sep 22, 2023
1 parent 39d7f70 commit a44b787
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1256,15 +1256,18 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<int64_t> readShape;
SmallVector<Value> dynamicDims;
for (auto i : llvm::seq<unsigned>(0, destRank)) {
if (dimAndTileMapping.count(i)) {
readSizes.push_back(oneIdxAttr);
continue;
}

if (ShapedType::isDynamic(srcShape[i])) {
readSizes.push_back(
rewriter.create<tensor::DimOp>(loc, source, i).getResult());
Value dynamicDim =
rewriter.create<tensor::DimOp>(loc, source, i).getResult();
readSizes.push_back(dynamicDim);
dynamicDims.push_back(dynamicDim);
} else {
readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
}
Expand Down Expand Up @@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<int64_t> transpShape(readShape);
applyPermutationToVector<int64_t>(transpShape, perm);

Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
Value empty =
rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// -----

func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
%0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
return %0 : tensor<?x1x32x8xf32>
}
// CHECK-LABEL: func.func @unpack_with_dynamic_dims
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]] : tensor<?x8x32xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x32x8xf32>)
// CHECK-SAME: permutation = [0, 2, 1]
// CHECK: %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

0 comments on commit a44b787

Please sign in to comment.