
[MLIR][linalg] Fix unpack rewriter for dynamic shapes #67096

Merged
merged 1 commit into llvm:main from fix_unpack on Sep 22, 2023

Conversation

qcolombet (Collaborator)

Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would hit an assertion stating that a `tensor.empty` operation cannot be created with dynamic shapes.

The problem stems from the fact that we were not using the right builder for the `tensor.empty` operation: each dynamic dimension needs to be specified by an input SSA value.

The fix is simply to provide the dynamic dimensions to the `tensor.empty` builder.
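
At the IR level, the rule looks like this. A minimal sketch (the function and value names are illustrative, not part of the patch): each `?` in the result type of `tensor.empty` must be matched by one index operand, typically produced by `tensor.dim`:

func.func @empty_with_dynamic_dim(%src: tensor<?x8xf32>) -> tensor<?x8xf32> {
  %c0 = arith.constant 0 : index
  // Query the runtime extent of the dynamic dimension.
  %d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
  // One index operand per '?' in the result type.
  %empty = tensor.empty(%d0) : tensor<?x8xf32>
  return %empty : tensor<?x8xf32>
}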

@llvmbot (Collaborator) commented Sep 22, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Changes



Full diff: https://github.com/llvm/llvm-project/pull/67096.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+7-3)
  • (modified) mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir (+23)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 49fe937741c77c9..8183b40ad7346f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1256,6 +1256,7 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
   SmallVector<int64_t> readShape;
+  SmallVector<Value> dynamicDims;
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
     if (dimAndTileMapping.count(i)) {
       readSizes.push_back(oneIdxAttr);
@@ -1263,8 +1264,10 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     }
 
     if (ShapedType::isDynamic(srcShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+      Value dynamicDim =
+          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
+      readSizes.push_back(dynamicDim);
+      dynamicDims.push_back(dynamicDim);
     } else {
       readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
     }
@@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape(readShape);
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index a596690c2e4fd60..023768088650062 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
+  return %0 : tensor<?x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_dynamic_dims
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<?x8x32xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<?x32x8xf32>)
+// CHECK-SAME:      permutation = [0, 2, 1]
+// CHECK:         %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
+// CHECK:         %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]
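
For readability, here is the generalized IR the pattern now produces for the new test case, assembled from the CHECK lines above (the SSA value names are illustrative):

func.func @unpack_with_dynamic_dims(%src: tensor<?x1x1x1x8x32xf32>, %dest: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
  %c0 = arith.constant 0 : index
  %dim0_src = tensor.dim %src, %c0 : tensor<?x1x1x1x8x32xf32>
  // Read the inner tile, keeping the dynamic outer dimension.
  %tile = tensor.extract_slice %src[0, 0, 0, 0, 0, 0] [%dim0_src, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
      : tensor<?x1x1x1x8x32xf32> to tensor<?x8x32xf32>
  // The dynamic dimension is now passed to tensor.empty.
  %empty = tensor.empty(%dim0_src) : tensor<?x32x8xf32>
  %transp = linalg.transpose ins(%tile : tensor<?x8x32xf32>)
      outs(%empty : tensor<?x32x8xf32>) permutation = [0, 2, 1]
  %dim0_dest = tensor.dim %dest, %c0 : tensor<?x1x32x8xf32>
  %slice = tensor.extract_slice %transp[0, 0, 0] [%dim0_dest, 32, 8] [1, 1, 1]
      : tensor<?x32x8xf32> to tensor<?x32x8xf32>
  %insert = tensor.insert_slice %slice into %dest[0, 0, 0, 0] [%dim0_dest, 1, 32, 8] [1, 1, 1, 1]
      : tensor<?x32x8xf32> into tensor<?x1x32x8xf32>
  return %insert : tensor<?x1x32x8xf32>
}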

@nicolasvasilache (Contributor) left a comment:


thanks!

@qcolombet merged commit a44b787 into llvm:main on Sep 22, 2023
4 checks passed
@qcolombet deleted the fix_unpack branch on September 22, 2023 at 10:23