Skip to content

Commit 7edff84

Browse files
committed
[mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim
1 parent a50cb6c commit 7edff84

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Interfaces/ViewLikeInterface.h"
2424
#include "llvm/ADT/STLExtras.h"
2525
#include "llvm/ADT/SmallBitVector.h"
26+
#include <algorithm>
2627

2728
using namespace mlir;
2829
using namespace mlir::memref;
@@ -2413,11 +2414,23 @@ computeCollapsedLayoutMap(MemRefType srcType,
24132414
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
24142415
resultStrides.push_back(srcStrides[ref.back()]);
24152416
} else {
2416-
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2417-
// the corresponding stride may have to be skipped. (See above comment.)
2418-
// Therefore, the result stride cannot be statically determined and must
2419-
// be dynamic.
2420-
resultStrides.push_back(ShapedType::kDynamic);
2417+
// We reach here if the last dimension in the reassociation group is dynamic,
2418+
// and the reassociation group has more than one dimension.
2419+
// If the dynamic dimension is preserved (all other dimensions in the group are of size 1),
2420+
// and the dynamic dimension is originally contiguous, the result stride will be 1.
2421+
bool contiguousSrcDim = srcStrides[ref.back()] == 1;
2422+
bool dynamicSizeIsPreserved =
2423+
std::all_of(ref.begin(), ref.end() - 1,
2424+
[srcShape](int64_t dim) { return srcShape[dim] == 1; });
2425+
if (contiguousSrcDim && dynamicSizeIsPreserved)
2426+
resultStrides.push_back(1);
2427+
else {
2428+
// Dynamically-sized dims may turn out to be dims of size 1 at runtime,
2429+
// so the corresponding stride may have to be skipped. (See above
2430+
// comment.) Therefore, the result stride cannot be statically
2431+
// determined and must be dynamic.
2432+
resultStrides.push_back(ShapedType::kDynamic);
2433+
}
24212434
}
24222435
}
24232436

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
502502

503503
// -----
504504

505+
func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
506+
// expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
507+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
508+
return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
509+
}
510+
511+
// -----
512+
505513
func.func @expand_shape_illegal_static_memref
506514
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
507515
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
417417
%arg4: index,
418418
%arg5: index,
419419
%arg6: index,
420-
%arg7: memref<4x?x4xf32>) {
420+
%arg7: memref<4x?x4xf32>,
421+
%arg8: memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
422+
%arg9: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
423+
%arg10: memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1>) {
421424
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
422425
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
423426
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -466,6 +469,22 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
466469
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
467470
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
468471
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
472+
473+
// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
474+
%5 = memref.collapse_shape %arg8 [[0, 1, 2, 3]] :
475+
memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
476+
memref<?xsi32, strided<[?]>, 1>
477+
478+
// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
479+
%6 = memref.collapse_shape %arg9 [[0, 1, 2, 3]] :
480+
memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
481+
memref<?xsi32, strided<[1]>, 1>
482+
483+
// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
484+
%7 = memref.collapse_shape %arg10 [[0, 1, 2, 3]] :
485+
memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1> into
486+
memref<?xsi32, strided<[?]>, 1>
487+
469488
return
470489
}
471490

0 commit comments

Comments
 (0)