diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index d56b32193765e..b6a1113b61597 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -23,6 +23,7 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include using namespace mlir; using namespace mlir::memref; @@ -2413,11 +2414,24 @@ computeCollapsedLayoutMap(MemRefType srcType, if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { resultStrides.push_back(srcStrides[ref.back()]); } else { - // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so - // the corresponding stride may have to be skipped. (See above comment.) - // Therefore, the result stride cannot be statically determined and must - // be dynamic. - resultStrides.push_back(ShapedType::kDynamic); + // We reach here if the last dimension in the reassociation group is + // dynamic, and the reassociation group has more than one dimension. + // If the dynamic dim is preserved (all other dimensions in the group are + // of size 1), and the dynamic dim is originally contiguous, the result + // stride will be 1. + bool contiguousSrcDim = srcStrides[ref.back()] == 1; + bool dynamicSizeIsPreserved = + std::all_of(ref.begin(), ref.end() - 1, + [srcShape](int64_t dim) { return srcShape[dim] == 1; }); + if (contiguousSrcDim && dynamicSizeIsPreserved) + resultStrides.push_back(1); + else { + // Dynamically-sized dims may turn out to be dims of size 1 at runtime, + // so the corresponding stride may have to be skipped. (See above + // comment.) Therefore, the result stride cannot be statically + // determined and must be dynamic. + resultStrides.push_back(ShapedType::kDynamic); + } } } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index f908efb638446..8c9d744a754ce 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref) { // ----- +func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref, 1>) { + // expected-error @+1 {{expected collapsed type to be 'memref, 1>' but found 'memref, 1>'}} + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref, 1> + return %collapse_shape : memref, 1> +} + +// ----- + func.func @expand_shape_illegal_static_memref (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> { // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 13fdf3cf13510..123ac1cf4de94 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -417,7 +417,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, %arg4: index, %arg5: index, %arg6: index, - %arg7: memref<4x?x4xf32>) { + %arg7: memref<4x?x4xf32>, + %arg8: memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, + %arg9: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, + %arg10: memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1>) { // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : @@ -466,6 +469,22 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2] : memref<4x?x4xf32> into memref<2x2x?x2x2xf32> + +// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]] + %5 = memref.collapse_shape %arg8 [[0, 1, 2, 3]] : + memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into + memref, 1> + +// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]] + %6 = memref.collapse_shape %arg9 [[0, 1, 2, 3]] : + memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into + memref, 1> + +// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]] + %7 = memref.collapse_shape %arg10 [[0, 1, 2, 3]] : + memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1> into + memref, 1> + return }