@@ -23,6 +23,8 @@ namespace bufferization {
23
23
using namespace mlir ;
24
24
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
25
25
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26
+ using AllocDynamicSizesMap =
27
+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
26
28
27
29
// / Return `true` if the given MemRef type has a fully dynamic layout.
28
30
static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -43,30 +45,24 @@ static bool hasStaticIdentityLayout(MemRefType type) {
43
45
return type.getLayout ().isIdentity ();
44
46
}
45
47
46
- // / Return the dynamic shapes of the `memref` based on the define op. If the
48
+ // / Return the dynamic shapes of the `memref` based on the defining op. If the
47
49
// / complete dynamic shape fails to be captured, return an empty value.
48
- // / Currently, only function parameters are supported for capturing.
50
+ // / Currently, only function block arguments are supported for capturing.
49
51
static SmallVector<Value> getDynamicSize (Value memref, func::FuncOp funcOp) {
50
- auto *defOp = memref.getDefiningOp ();
52
+ Operation *defOp = memref.getDefiningOp ();
51
53
if (!defOp)
52
54
return {};
53
55
auto operands = defOp->getOperands ();
54
56
SmallVector<Value> dynamicSizes;
55
57
for (Value size : operands) {
56
- BlockArgument sizeSrc = mlir:: dyn_cast<BlockArgument>(size);
58
+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
57
59
if (!sizeSrc)
58
60
return {};
59
61
60
- bool finded = false ;
61
- for (BlockArgument argument : funcOp.getArguments ()) {
62
- if (argument == sizeSrc) {
63
- dynamicSizes.push_back (argument);
64
- finded = true ;
65
- break ;
66
- }
67
- }
68
- if (!finded)
62
+ auto iter = llvm::find (funcOp.getArguments (), sizeSrc);
63
+ if (!iter)
69
64
return {};
65
+ dynamicSizes.push_back (*iter);
70
66
}
71
67
return dynamicSizes;
72
68
}
@@ -76,18 +72,20 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
76
72
static SmallVector<Value> mapDynamicSizeAtCaller (func::CallOp call,
77
73
func::FuncOp callee,
78
74
ValueRange dynamicSizes) {
79
- SmallVector<Value> mapedDynamicSizes ;
75
+ SmallVector<Value> mappedDynamicSizes ;
80
76
for (Value size : dynamicSizes) {
81
77
auto callOperands = call.getOperands ();
82
78
for (size_t i = 0 , e = callOperands.size (); i < e; ++i) {
83
79
Value src = callOperands[i];
84
80
BlockArgument dst = callee.getArgument (i);
85
81
if (size != dst)
86
82
continue ;
87
- mapedDynamicSizes .push_back (src);
83
+ mappedDynamicSizes .push_back (src);
88
84
}
89
85
}
90
- return mapedDynamicSizes;
86
+ assert (mappedDynamicSizes.size () == dynamicSizes.size () &&
87
+ " could not find all dynamic sizes" );
88
+ return mappedDynamicSizes;
91
89
}
92
90
93
91
// Updates the func op and entry block.
@@ -156,7 +154,8 @@ updateFuncOp(func::FuncOp func,
156
154
// the given out-params.
157
155
static LogicalResult
158
156
updateReturnOps (func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
159
- bufferization::BufferResultsToOutParamsOpts &options) {
157
+ AllocDynamicSizesMap &map,
158
+ const bufferization::BufferResultsToOutParamsOpts &options) {
160
159
auto res = func.walk ([&](func::ReturnOp op) {
161
160
SmallVector<Value, 6 > copyIntoOutParams;
162
161
SmallVector<Value, 6 > keepAsReturnOperands;
@@ -171,10 +170,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
171
170
for (auto [orig, arg] : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
172
171
bool hoistStaticAllocs =
173
172
options.hoistStaticAllocs &&
174
- mlir:: cast<MemRefType>(orig.getType ()).hasStaticShape ();
173
+ cast<MemRefType>(orig.getType ()).hasStaticShape ();
175
174
bool hoistDynamicAllocs =
176
175
options.hoistDynamicAllocs &&
177
- !mlir:: cast<MemRefType>(orig.getType ()).hasStaticShape ();
176
+ !cast<MemRefType>(orig.getType ()).hasStaticShape ();
178
177
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179
178
isa_and_nonnull<bufferization::AllocationOpInterface>(
180
179
orig.getDefiningOp ())) {
@@ -194,7 +193,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
194
193
auto dynamicSizePair =
195
194
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196
195
dynamicSizes);
197
- options. dynamicSizesMap .insert (dynamicSizePair);
196
+ map .insert (dynamicSizePair);
198
197
return WalkResult::advance ();
199
198
});
200
199
return failure (res.wasInterrupted ());
@@ -203,7 +202,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
203
202
// Updates all CallOps in the scope of the given ModuleOp by allocating
204
203
// temporary buffers for newly introduced out params.
205
204
static LogicalResult
206
- updateCalls (ModuleOp module ,
205
+ updateCalls (ModuleOp module , AllocDynamicSizesMap &map,
207
206
const bufferization::BufferResultsToOutParamsOpts &options) {
208
207
bool didFail = false ;
209
208
SymbolTable symtab (module );
@@ -227,8 +226,7 @@ updateCalls(ModuleOp module,
227
226
}
228
227
SmallVector<Value, 6 > outParams;
229
228
OpBuilder builder (op);
230
- SmallVector<SmallVector<Value>> dynamicSizes =
231
- options.dynamicSizesMap .lookup (callee);
229
+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup (callee);
232
230
size_t dynamicSizesIndex = 0 ;
233
231
for (Value memref : replaceWithOutParams) {
234
232
SmallVector<Value> dynamicSize = dynamicSizes.size () > dynamicSizesIndex
@@ -287,7 +285,11 @@ updateCalls(ModuleOp module,
287
285
}
288
286
289
287
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
290
- ModuleOp module , bufferization::BufferResultsToOutParamsOpts &options) {
288
+ ModuleOp module ,
289
+ const bufferization::BufferResultsToOutParamsOpts &options) {
290
+ // / It maps the shape source of the dynamic shape memref returned by each
291
+ // / function.
292
+ AllocDynamicSizesMap map;
291
293
for (auto func : module .getOps <func::FuncOp>()) {
292
294
if (!options.filterFn (&func))
293
295
continue ;
@@ -297,11 +299,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
297
299
return failure ();
298
300
if (func.isExternal ())
299
301
continue ;
300
- if (failed (updateReturnOps (func, appendedEntryArgs, options))) {
302
+ if (failed (updateReturnOps (func, appendedEntryArgs, map, options))) {
301
303
return failure ();
302
304
}
303
305
}
304
- if (failed (updateCalls (module , options)))
306
+ if (failed (updateCalls (module , map, options)))
305
307
return failure ();
306
308
return success ();
307
309
}
0 commit comments