diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 49fe937741c77c..8183b40ad7346f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1256,6 +1256,7 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( SmallVector readStrides(srcRank, oneIdxAttr); SmallVector readSizes; SmallVector readShape; + SmallVector dynamicDims; for (auto i : llvm::seq(0, destRank)) { if (dimAndTileMapping.count(i)) { readSizes.push_back(oneIdxAttr); @@ -1263,8 +1264,10 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( } if (ShapedType::isDynamic(srcShape[i])) { - readSizes.push_back( - rewriter.create(loc, source, i).getResult()); + Value dynamicDim = + rewriter.create(loc, source, i).getResult(); + readSizes.push_back(dynamicDim); + dynamicDims.push_back(dynamicDim); } else { readSizes.push_back(rewriter.getIndexAttr(srcShape[i])); } @@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( SmallVector transpShape(readShape); applyPermutationToVector(transpShape, perm); - Value empty = rewriter.create(loc, transpShape, elemType); + Value empty = + rewriter.create(loc, transpShape, elemType, dynamicDims); auto transposedOp = rewriter.create(loc, innerTile, empty, perm); diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir index a596690c2e4fd6..02376808865006 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir @@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1] // CHECK: return %[[INSERT]] + +// ----- + +func.func @unpack_with_dynamic_dims(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @unpack_with_dynamic_dims +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor +// CHECK: %[[TRANSP:.+]] = linalg.transpose +// CHECK-SAME: ins(%[[TILE]] : tensor) +// CHECK-SAME: outs(%[[EMPTY]] : tensor) +// CHECK-SAME: permutation = [0, 2, 1] +// CHECK: %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor +// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor to tensor +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1] +// CHECK: return %[[INSERT]]