Skip to content

Commit 100dfcc

Browse files
update.
1 parent 3484969 commit 100dfcc

File tree

2 files changed

+29
-31
lines changed

2 files changed

+29
-31
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,13 @@ struct BufferResultsToOutParamsOpts {
171171
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
172172
/// memref is dynamic allocated in the current function.
173173
bool hoistDynamicAllocs = false;
174-
175-
/// It maps the shape source of the dynamic shape memref returned by each
176-
/// function.
177-
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>> dynamicSizesMap;
178174
};
179175

180176
/// Replace buffers that are returned from a function with an out parameter.
181177
/// Also update all call sites.
182178
LogicalResult
183179
promoteBufferResultsToOutParams(ModuleOp module,
184-
BufferResultsToOutParamsOpts &options);
180+
const BufferResultsToOutParamsOpts &options);
185181

186182
/// Drop all memref function results that are equivalent to a function argument.
187183
LogicalResult dropEquivalentBufferResults(ModuleOp module);

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace bufferization {
2323
using namespace mlir;
2424
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2525
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26+
using AllocDynamicSizesMap =
27+
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
2628

2729
/// Return `true` if the given MemRef type has a fully dynamic layout.
2830
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -43,30 +45,24 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4345
return type.getLayout().isIdentity();
4446
}
4547

46-
/// Return the dynamic shapes of the `memref` based on the define op. If the
48+
/// Return the dynamic shapes of the `memref` based on the defining op. If the
4749
/// complete dynamic shape fails to be captured, return an empty value.
48-
/// Currently, only function parameters are supported for capturing.
50+
/// Currently, only function block arguments are supported for capturing.
4951
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
50-
auto *defOp = memref.getDefiningOp();
52+
Operation *defOp = memref.getDefiningOp();
5153
if (!defOp)
5254
return {};
5355
auto operands = defOp->getOperands();
5456
SmallVector<Value> dynamicSizes;
5557
for (Value size : operands) {
56-
BlockArgument sizeSrc = mlir::dyn_cast<BlockArgument>(size);
58+
BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
5759
if (!sizeSrc)
5860
return {};
5961

60-
bool finded = false;
61-
for (BlockArgument argument : funcOp.getArguments()) {
62-
if (argument == sizeSrc) {
63-
dynamicSizes.push_back(argument);
64-
finded = true;
65-
break;
66-
}
67-
}
68-
if (!finded)
62+
auto iter = llvm::find(funcOp.getArguments(), sizeSrc);
63+
if (!iter)
6964
return {};
65+
dynamicSizes.push_back(*iter);
7066
}
7167
return dynamicSizes;
7268
}
@@ -76,18 +72,20 @@ static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
7672
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
7773
func::FuncOp callee,
7874
ValueRange dynamicSizes) {
79-
SmallVector<Value> mapedDynamicSizes;
75+
SmallVector<Value> mappedDynamicSizes;
8076
for (Value size : dynamicSizes) {
8177
auto callOperands = call.getOperands();
8278
for (size_t i = 0, e = callOperands.size(); i < e; ++i) {
8379
Value src = callOperands[i];
8480
BlockArgument dst = callee.getArgument(i);
8581
if (size != dst)
8682
continue;
87-
mapedDynamicSizes.push_back(src);
83+
mappedDynamicSizes.push_back(src);
8884
}
8985
}
90-
return mapedDynamicSizes;
86+
assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
87+
"could not find all dynamic sizes");
88+
return mappedDynamicSizes;
9189
}
9290

9391
// Updates the func op and entry block.
@@ -156,7 +154,8 @@ updateFuncOp(func::FuncOp func,
156154
// the given out-params.
157155
static LogicalResult
158156
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
159-
bufferization::BufferResultsToOutParamsOpts &options) {
157+
AllocDynamicSizesMap &map,
158+
const bufferization::BufferResultsToOutParamsOpts &options) {
160159
auto res = func.walk([&](func::ReturnOp op) {
161160
SmallVector<Value, 6> copyIntoOutParams;
162161
SmallVector<Value, 6> keepAsReturnOperands;
@@ -171,10 +170,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
171170
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172171
bool hoistStaticAllocs =
173172
options.hoistStaticAllocs &&
174-
mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
173+
cast<MemRefType>(orig.getType()).hasStaticShape();
175174
bool hoistDynamicAllocs =
176175
options.hoistDynamicAllocs &&
177-
!mlir::cast<MemRefType>(orig.getType()).hasStaticShape();
176+
!cast<MemRefType>(orig.getType()).hasStaticShape();
178177
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179178
isa_and_nonnull<bufferization::AllocationOpInterface>(
180179
orig.getDefiningOp())) {
@@ -194,7 +193,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
194193
auto dynamicSizePair =
195194
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196195
dynamicSizes);
197-
options.dynamicSizesMap.insert(dynamicSizePair);
196+
map.insert(dynamicSizePair);
198197
return WalkResult::advance();
199198
});
200199
return failure(res.wasInterrupted());
@@ -203,7 +202,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
203202
// Updates all CallOps in the scope of the given ModuleOp by allocating
204203
// temporary buffers for newly introduced out params.
205204
static LogicalResult
206-
updateCalls(ModuleOp module,
205+
updateCalls(ModuleOp module, AllocDynamicSizesMap &map,
207206
const bufferization::BufferResultsToOutParamsOpts &options) {
208207
bool didFail = false;
209208
SymbolTable symtab(module);
@@ -227,8 +226,7 @@ updateCalls(ModuleOp module,
227226
}
228227
SmallVector<Value, 6> outParams;
229228
OpBuilder builder(op);
230-
SmallVector<SmallVector<Value>> dynamicSizes =
231-
options.dynamicSizesMap.lookup(callee);
229+
SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
232230
size_t dynamicSizesIndex = 0;
233231
for (Value memref : replaceWithOutParams) {
234232
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
@@ -287,7 +285,11 @@ updateCalls(ModuleOp module,
287285
}
288286

289287
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
290-
ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) {
288+
ModuleOp module,
289+
const bufferization::BufferResultsToOutParamsOpts &options) {
290+
/// It maps the shape source of the dynamic shape memref returned by each
291+
/// function.
292+
AllocDynamicSizesMap map;
291293
for (auto func : module.getOps<func::FuncOp>()) {
292294
if (!options.filterFn(&func))
293295
continue;
@@ -297,11 +299,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
297299
return failure();
298300
if (func.isExternal())
299301
continue;
300-
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
302+
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
301303
return failure();
302304
}
303305
}
304-
if (failed(updateCalls(module, options)))
306+
if (failed(updateCalls(module, map, options)))
305307
return failure();
306308
return success();
307309
}

0 commit comments

Comments
 (0)