Skip to content

Commit 541f33e

Browse files
authored
[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases (#145235)
This patch adds additional checks to the hoisting logic to prevent hoisting of `vector.transfer_read` / `vector.transfer_write` pairs when the underlying memref has users that introduce aliases via operations implementing `ViewLikeOpInterface`. Note: This may conservatively block some valid hoisting opportunities and could affect performance. However, as demonstrated by the included tests, the current logic is too permissive and can lead to incorrect transformations. If this change prevents hoisting in cases that are provably safe, please share a minimal repro - I'm happy to explore ways to relax the check. Special treatment is given to `memref.assume_alignment`, mainly to accommodate recent updates in: * #139521 Note that such special casing does not scale and should generally be avoided. The current hoisting logic lacks robust alias analysis. While better support would require more work, the broader semantics of `memref.assume_alignment` remain somewhat unclear. It's possible this op may eventually be replaced with the "alignment" attribute added in: * #144344
1 parent 7e2e030 commit 541f33e

File tree

2 files changed

+265
-7
lines changed

2 files changed

+265
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,23 +303,51 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
303303
// 1. indices, vector type and permutation map are the same (i.e., the
304304
// transfer_read/transfer_write ops are matching),
305305
// 2. source operands for transfer.{read|write} do not originate from
306-
// Ops implementing ViewLikeOpInterface.
306+
// nor have users that are Ops implementing ViewLikeOpInterface.
307307
// 3. no other operations in the loop access the same memref except
308308
// for transfer_read/transfer_write accessing statically disjoint
309309
// slices.
310+
311+
// Check 1.
310312
if (transferRead.getIndices() != transferWrite.getIndices() ||
311313
transferRead.getVectorType() != transferWrite.getVectorType() ||
312314
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313315
return WalkResult::advance();
314316

315-
auto *source = transferRead.getBase().getDefiningOp();
316-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317-
return WalkResult::advance();
317+
// Check 2. Note, since both xfer Ops share the source, we only need to
318+
// look at one of them.
319+
auto base = transferRead.getBase();
320+
auto *source = base.getDefiningOp();
321+
if (source) {
322+
// NOTE: We treat `memref.assume_alignment` as a special case.
323+
//
324+
// The idea is that it is safe to look past AssumeAlignmentOp (i.e.
325+
// MemRef _before_ alignment) iff:
326+
// 1. It has exactly two uses (these have to be the xfer Ops
327+
// being looked at).
328+
// 2. The original MemRef has only one use (i.e.
329+
// AssumeAlignmentOp).
330+
//
331+
// Relaxing these conditions will most likely require proper alias
332+
// analysis.
333+
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
334+
Value memPreAlignment = assume.getMemref();
335+
auto numInLoopUses =
336+
llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
337+
return loop->isAncestor(use.getOwner());
338+
});
339+
340+
if (numInLoopUses && memPreAlignment.hasOneUse())
341+
source = memPreAlignment.getDefiningOp();
342+
}
343+
if (isa_and_nonnull<ViewLikeOpInterface>(source))
344+
return WalkResult::advance();
345+
}
318346

319-
source = transferWrite.getBase().getDefiningOp();
320-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
347+
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
321348
return WalkResult::advance();
322349

350+
// Check 3.
323351
// TODO: may want to memoize this information for performance but it
324352
// likely gets invalidated often.
325353
DominanceInfo dom(loop);
@@ -358,7 +386,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
358386
// Hoist write after.
359387
transferWrite->moveAfter(loop);
360388

361-
// Rewrite `loop` with new yields by cloning and erase the original loop.
389+
// Rewrite `loop` with new yields by cloning and erase the original
390+
// loop.
362391
IRRewriter rewriter(transferRead.getContext());
363392
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
364393
ArrayRef<BlockArgument> newBBArgs) {

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,234 @@
11
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22

3+
///----------------------------------------------------------------------------------------
4+
/// Tests for vector.transfer_read + vector.transfer_write pairs
5+
///
6+
/// * Nested inside a single loop
7+
/// * Indices are constant
8+
///----------------------------------------------------------------------------------------
9+
10+
// The most basic example - hoisting is safe.
11+
12+
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
13+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
14+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
15+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
16+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
17+
func.func @hoist_basic_vector_xfer_pair(
18+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
19+
%c0 = arith.constant 0 : index
20+
%pad = arith.constant 0.0 : f32
21+
22+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
23+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
24+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
25+
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
26+
// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
27+
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
28+
// CHECK: }
29+
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
30+
scf.for %i = %lb to %ub step %step {
31+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
32+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
33+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
34+
}
35+
return
36+
}
37+
38+
module attributes {transform.with_named_sequence} {
39+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
40+
%0 = transform.structured.match ops{["func.func"]} in %arg1
41+
: (!transform.any_op) -> !transform.any_op
42+
transform.structured.hoist_redundant_vector_transfers %0
43+
: (!transform.any_op) -> !transform.any_op
44+
transform.yield
45+
}
46+
}
47+
48+
// -----
49+
50+
// Similar to the example above, but hoisting is no longer safe. That's due to
51+
// an extra xfer_write inside the loop.
52+
53+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
54+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
55+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
56+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
57+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
58+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
59+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
60+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
61+
%c0 = arith.constant 0 : index
62+
%pad = arith.constant 0.0 : f32
63+
64+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
65+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
66+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
67+
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
68+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
69+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
70+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
71+
// CHECK: }
72+
73+
scf.for %i = %lb to %ub step %step {
74+
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
75+
76+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
77+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
78+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
79+
}
80+
return
81+
}
82+
83+
module attributes {transform.with_named_sequence} {
84+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
85+
%0 = transform.structured.match ops{["func.func"]} in %arg1
86+
: (!transform.any_op) -> !transform.any_op
87+
transform.structured.hoist_redundant_vector_transfers %0
88+
: (!transform.any_op) -> !transform.any_op
89+
transform.yield
90+
}
91+
}
92+
93+
// -----
94+
95+
// Similar to the example above, but hoisting is no longer safe. That's due to
96+
// an extra xfer_write into _an alias_ of the %mem Op that is used by the
97+
// original xfer pair.
98+
99+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
100+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
101+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
102+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
103+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
104+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
105+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
106+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
107+
%c0 = arith.constant 0 : index
108+
%pad = arith.constant 0.0 : f32
109+
110+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
111+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
112+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
113+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
114+
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
115+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
116+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
117+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
118+
// CHECK: }
119+
120+
%sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
121+
scf.for %i = %lb to %ub step %step {
122+
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
123+
124+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
125+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
126+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
127+
}
128+
return
129+
}
130+
131+
module attributes {transform.with_named_sequence} {
132+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
133+
%0 = transform.structured.match ops{["func.func"]} in %arg1
134+
: (!transform.any_op) -> !transform.any_op
135+
transform.structured.hoist_redundant_vector_transfers %0
136+
: (!transform.any_op) -> !transform.any_op
137+
transform.yield
138+
}
139+
}
140+
141+
// -----
142+
143+
// Similar to the example above, but the memory access is done via
144+
// memref.assume_alignment. Hoisting is safe as the only users of the
145+
// "alignment" Op are the xfer Ops within the loop that we want to hoist.
146+
147+
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align(
148+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
149+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
150+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
151+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
152+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
153+
func.func @hoist_basic_vector_xfer_pair_with_assume_align(
154+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
155+
%c0 = arith.constant 0 : index
156+
%pad = arith.constant 0.0 : f32
157+
158+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
159+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
160+
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
161+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
162+
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
163+
// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
164+
// CHECK: }
165+
// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
166+
167+
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
168+
scf.for %i = %lb to %ub step %step {
169+
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
170+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
171+
vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
172+
}
173+
return
174+
}
175+
176+
module attributes {transform.with_named_sequence} {
177+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
178+
%0 = transform.structured.match ops{["func.func"]} in %arg1
179+
: (!transform.any_op) -> !transform.any_op
180+
transform.structured.hoist_redundant_vector_transfers %0
181+
: (!transform.any_op) -> !transform.any_op
182+
transform.yield
183+
}
184+
}
185+
186+
// -----
187+
188+
// Similar to the example above, but hoisting is not safe due to extra memory
189+
// access inside the loop via the original memref.
190+
191+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
192+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
193+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
194+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
195+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
196+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
197+
func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
198+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
199+
%c0 = arith.constant 0 : index
200+
%pad = arith.constant 0.0 : f32
201+
202+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
203+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
204+
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
205+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
206+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
207+
// CHECK: "mem_use"(%[[MEM]])
208+
// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
209+
// CHECK: }
210+
211+
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
212+
scf.for %i = %lb to %ub step %step {
213+
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
214+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
215+
vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
216+
}
217+
return
218+
}
219+
220+
module attributes {transform.with_named_sequence} {
221+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
222+
%0 = transform.structured.match ops{["func.func"]} in %arg1
223+
: (!transform.any_op) -> !transform.any_op
224+
transform.structured.hoist_redundant_vector_transfers %0
225+
: (!transform.any_op) -> !transform.any_op
226+
transform.yield
227+
}
228+
}
229+
230+
// -----
231+
3232
///----------------------------------------------------------------------------------------
4233
/// Tests for vector.transfer_read + vector.transfer_write pairs
5234
///

0 commit comments

Comments
 (0)