Skip to content

[mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim #136485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::memref;
Expand Down Expand Up @@ -2413,11 +2414,24 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
// the corresponding stride may have to be skipped. (See above comment.)
// Therefore, the result stride cannot be statically determined and must
// be dynamic.
resultStrides.push_back(ShapedType::kDynamic);
// We reach here if the last dimension in the reassociation group is
// dynamic, and the reassociation group has more than one dimension.
// If the dynamic dim is preserved (all other dimensions in the group are
// of size 1), and the dynamic dim is originally contiguous, the result
// stride will be 1.
bool contiguousSrcDim = srcStrides[ref.back()] == 1;
bool dynamicSizeIsPreserved =
std::all_of(ref.begin(), ref.end() - 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it ref.begin() -> ref.end() - 1? All dimensions except for one must be 1, right? In that case, it does not matter where non-unit dimension is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some explanation.
Hope it is clearer now.

[srcShape](int64_t dim) { return srcShape[dim] == 1; });
if (contiguousSrcDim && dynamicSizeIsPreserved)
resultStrides.push_back(1);
else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime,
// so the corresponding stride may have to be skipped. (See above
// comment.) Therefore, the result stride cannot be statically
// determined and must be dynamic.
resultStrides.push_back(ShapedType::kDynamic);
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {

// -----

func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
// expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another example where not all of the static source dimensions are 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
}

// -----

func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
Expand Down
21 changes: 20 additions & 1 deletion mlir/test/Dialect/MemRef/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg4: index,
%arg5: index,
%arg6: index,
%arg7: memref<4x?x4xf32>) {
%arg7: memref<4x?x4xf32>,
%arg8: memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
%arg9: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
%arg10: memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1>) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
Expand Down Expand Up @@ -466,6 +469,22 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>

// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
%5 = memref.collapse_shape %arg8 [[0, 1, 2, 3]] :
memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
memref<?xsi32, strided<[?]>, 1>

// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
%6 = memref.collapse_shape %arg9 [[0, 1, 2, 3]] :
memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
memref<?xsi32, strided<[1]>, 1>

// CHECK: collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
%7 = memref.collapse_shape %arg10 [[0, 1, 2, 3]] :
memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1> into
memref<?xsi32, strided<[?]>, 1>

return
}

Expand Down
Loading