File tree Expand file tree Collapse file tree 3 files changed +32
-0
lines changed
include/mlir/Dialect/MemRef/IR Expand file tree Collapse file tree 3 files changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -2333,6 +2333,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
2333
2333
2334
2334
let hasCanonicalizer = 1;
2335
2335
let hasVerifier = 1;
2336
+ let hasFolder = 1;
2336
2337
}
2337
2338
2338
2339
//===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -3463,6 +3463,16 @@ LogicalResult ViewOp::verify() {
3463
3463
3464
3464
Value ViewOp::getViewSource () { return getSource (); }
3465
3465
3466
+ OpFoldResult ViewOp::fold (FoldAdaptor adaptor) {
3467
+ MemRefType sourceMemrefType = getSource ().getType ();
3468
+ MemRefType resultMemrefType = getResult ().getType ();
3469
+
3470
+ if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape ())
3471
+ return getViewSource ();
3472
+
3473
+ return {};
3474
+ }
3475
+
3466
3476
namespace {
3467
3477
3468
3478
struct ViewOpShapeFolder : public OpRewritePattern <ViewOp> {
Original file line number Diff line number Diff line change @@ -1208,3 +1208,24 @@ func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> {
1208
1208
// CHECK: return %[[ALIGN]]
1209
1209
return %2 : memref <128 xf32 >
1210
1210
}
1211
+
1212
+ // -----
1213
+
1214
+ // CHECK-LABEL: func @fold_view_same_source_result_types
1215
+ func.func @fold_view_same_source_result_types (%0: memref <128 xi8 >) -> memref <128 xi8 > {
1216
+ %c0 = arith.constant 0 : index
1217
+ // CHECK-NOT: memref.view
1218
+ %res = memref.view %0 [%c0 ][] : memref <128 xi8 > to memref <128 xi8 >
1219
+ return %res : memref <128 xi8 >
1220
+ }
1221
+
1222
+ // -----
1223
+
1224
+ // CHECK-LABEL: func @non_fold_view_same_source_res_types
1225
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1226
+ func.func @non_fold_view_same_source_res_types (%0: memref <?xi8 >, %arg0 : index ) -> memref <?xi8 > {
1227
+ %c0 = arith.constant 0 : index
1228
+ // CHECK: memref.view
1229
+ %res = memref.view %0 [%c0 ][%arg0 ] : memref <?xi8 > to memref <?xi8 >
1230
+ return %res : memref <?xi8 >
1231
+ }
You can’t perform that action at this time.
0 commit comments