From 3b387e7baae4a098ee8bd90ecf1797967c2c43bf Mon Sep 17 00:00:00 2001 From: Zhang Xiang Date: Wed, 18 Jun 2025 17:10:15 +0800 Subject: [PATCH 1/2] [mlir][NFC] Pre-commit test for linalg hoisting --- mlir/test/Dialect/Linalg/hoisting.mlir | 51 ++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 318edca73cce1..67dfe7a2af98b 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -802,3 +802,54 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// Test hoisting of vector.transfer_read/transfer_write pairs with same location +// and this location is marked with assume_align. + +// CHECK-LABEL: func.func @hoist_vector_transfer_read_write() { +// CHECK: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c256 = arith.constant 256 : index +// CHECK-NEXT: %c4096 = arith.constant 4096 : index +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16 +// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16> +// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16> +// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16> +// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 { +// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> +// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> +// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> +// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16> +// CHECK-NEXT: } +// CHECK-NEXT: 
return +// CHECK-NEXT: } + +func.func @hoist_vector_transfer_read_write() { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %cst_0 = arith.constant 0.000000e+00 : f16 + %m0 = memref.alloc() : memref<4096x4096xf16> + %m1 = memref.alloc() : memref<4096x4096xf16> + %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16> + %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16> + scf.for %arg0 = %c256 to %c4096 step %c256 { + %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> + %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> + %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} From 706c1daea66acd1b899aac37a6716a1bf9255c3c Mon Sep 17 00:00:00 2001 From: Zhang Xiang Date: Wed, 18 Jun 2025 14:05:04 +0800 Subject: [PATCH 2/2] [mlir][hoisting] Support memref.assume_alignment in linalg hoisting All ViewLike operations are excluded by the hoisting optimization. 
But assume_alignment just marks a memref's alignment; we should check its source memref instead of the op itself. --- .../Dialect/Linalg/Transforms/Hoisting.cpp | 26 +++++++++++++++---- mlir/test/Dialect/Linalg/hoisting.mlir | 11 ++++---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 707b63ff9335b..b949b06631484 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, return true; } +static bool skipViewLike(Operation *source0, Operation *source1) { + bool viewLikeCheck = true; + auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0); + if (assumeAlignOp && source0 == source1) { + Value sourceMemRef = assumeAlignOp.getMemref(); + Operation *sourceOp = sourceMemRef.getDefiningOp(); + return isa_and_nonnull<ViewLikeOpInterface>(sourceOp); + } + + if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0)) + return true; + + if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1)) + return true; + + return false; +} + void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip) { bool changed = true; @@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); - auto *source = transferRead.getBase().getDefiningOp(); - if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) - return WalkResult::advance(); + auto *source0 = transferRead.getBase().getDefiningOp(); + auto *source1 = transferWrite.getBase().getDefiningOp(); - source = transferWrite.getBase().getDefiningOp(); - if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) + if (skipViewLike(source0, source1)) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 
67dfe7a2af98b..c58074e40c5f4 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -816,12 +816,13 @@ module attributes {transform.with_named_sequence} { // CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16> // CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16> // CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16> -// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 { -// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> -// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> -// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> -// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16> +// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> +// CHECK-NEXT: %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) { +// CHECK-NEXT: %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16> +// CHECK-NEXT: %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> +// CHECK-NEXT: scf.yield %3 : vector<16x16xf16> // CHECK-NEXT: } +// CHECK-NEXT: vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16> // CHECK-NEXT: return // CHECK-NEXT: }