Skip to content

Commit d4d95da

Browse files
committed
fixup! fixup! fixup! fixup! [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases
Extra test
1 parent a991957 commit d4d95da

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
///----------------------------------------------------------------------------------------
44
/// Tests for vector.transfer_read + vector.transfer_write pairs
55
///
6-
/// * Indices are static
7-
/// * Single loop
6+
/// * Nested inside a single loop
7+
// * Indices are constant
88
///----------------------------------------------------------------------------------------
99

1010
// The most basic example - hoisting is safe.
@@ -23,13 +23,13 @@ func.func @hoist_basic_vector_xfer_pair(
2323
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
2424
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
2525
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
26-
// CHECK: %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
26+
// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
2727
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
2828
// CHECK: }
2929
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
3030
scf.for %i = %lb to %ub step %step {
3131
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
32-
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
32+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
3333
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
3434
}
3535
return
@@ -66,15 +66,15 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
6666
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
6767
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
6868
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
69-
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
69+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
7070
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
7171
// CHECK: }
7272

7373
scf.for %i = %lb to %ub step %step {
7474
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
7575

7676
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
77-
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
77+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
7878
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
7979
}
8080
return
@@ -113,7 +113,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
113113
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
114114
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
115115
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
116-
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
116+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
117117
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
118118
// CHECK: }
119119

@@ -122,7 +122,7 @@ func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
122122
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
123123

124124
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
125-
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
125+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
126126
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
127127
}
128128
return
@@ -160,14 +160,14 @@ func.func @hoist_basic_vector_xfer_pair_with_assume_align(
160160
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
161161
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
162162
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
163-
// CHECK: %[[USE:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
163+
// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
164164
// CHECK: }
165165
// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
166166

167167
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
168168
scf.for %i = %lb to %ub step %step {
169169
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
170-
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
170+
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
171171
vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
172172
}
173173
return
@@ -185,6 +185,50 @@ module attributes {transform.with_named_sequence} {
185185

186186
// -----
187187

188+
// Similar as the example above, but hoisting is not safe due to extra memory
189+
// access inside the loop via the original memref.
190+
191+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
192+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
193+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
194+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
195+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
196+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
197+
func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
198+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
199+
%c0 = arith.constant 0 : index
200+
%pad = arith.constant 0.0 : f32
201+
202+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
203+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
204+
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
205+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
206+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
207+
// CHECK: "mem_use"(%[[MEM]])
208+
// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
209+
// CHECK: }
210+
211+
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
212+
scf.for %i = %lb to %ub step %step {
213+
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
214+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
215+
vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
216+
}
217+
return
218+
}
219+
220+
module attributes {transform.with_named_sequence} {
221+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
222+
%0 = transform.structured.match ops{["func.func"]} in %arg1
223+
: (!transform.any_op) -> !transform.any_op
224+
transform.structured.hoist_redundant_vector_transfers %0
225+
: (!transform.any_op) -> !transform.any_op
226+
transform.yield
227+
}
228+
}
229+
230+
// -----
231+
188232
///----------------------------------------------------------------------------------------
189233
/// Tests for vector.transfer_read + vector.transfer_write pairs
190234
///

0 commit comments

Comments
 (0)