Skip to content

Commit 07c043c

Browse files
committed
[MLIR][MemRef]-Add basic folding for memref ViewOp
Add a folding for MemRef::ViewOp where the source memref type and the result memref type are similar.
1 parent b8f1228 commit 07c043c

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2333,6 +2333,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
23332333

23342334
let hasCanonicalizer = 1;
23352335
let hasVerifier = 1;
2336+
let hasFolder = 1;
23362337
}
23372338

23382339
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3463,6 +3463,16 @@ LogicalResult ViewOp::verify() {
34633463

34643464
Value ViewOp::getViewSource() { return getSource(); }
34653465

3466+
OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3467+
MemRefType sourceMemrefType = getSource().getType();
3468+
MemRefType resultMemrefType = getResult().getType();
3469+
3470+
if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3471+
return getViewSource();
3472+
3473+
return {};
3474+
}
3475+
34663476
namespace {
34673477

34683478
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,3 +1208,24 @@ func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> {
12081208
// CHECK: return %[[ALIGN]]
12091209
return %2 : memref<128xf32>
12101210
}
1211+
1212+
// -----
1213+
1214+
// CHECK-LABEL: func @fold_view_same_source_result_types
1215+
func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128xi8> {
1216+
%c0 = arith.constant 0: index
1217+
// CHECK-NOT: memref.view
1218+
%res = memref.view %0[%c0][] : memref<128xi8> to memref<128xi8>
1219+
return %res : memref<128xi8>
1220+
}
1221+
1222+
// -----
1223+
1224+
// CHECK-LABEL: func @non_fold_view_same_source_res_types
1225+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1226+
func.func @non_fold_view_same_source_res_types(%0: memref<?xi8>, %arg0 : index) -> memref<?xi8> {
1227+
%c0 = arith.constant 0: index
1228+
// CHECK: memref.view
1229+
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
1230+
return %res : memref<?xi8>
1231+
}

0 commit comments

Comments
 (0)