1
- // ===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
1
+ // ===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
2
+ // ----===//
2
3
//
3
4
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
5
// See https://llvm.org/LICENSE.txt for license information.
8
9
//
9
10
// Module Bufferization is an extension of One-Shot Bufferize that
10
11
// bufferizes function boundaries. It provides `BufferizableOpInterface`
11
- // implementations for FuncOp, CallOp and ReturnOp.
12
+ // implementations for FuncOp, CallOp and ReturnOp. Although it is named
13
+ // Module Bufferization, it may operate on any SymbolTable.
12
14
//
13
- // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14
- // This function analyzes the given module and determines the order of analysis
15
- // and bufferization: Functions that are called are processed before their
16
- // respective callers.
15
+ // Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
16
+ // ...)`. This function analyzes the given op and determines the order of
17
+ // analysis and bufferization: Functions that are called are processed before
18
+ // their respective callers.
17
19
//
18
20
// After analyzing a FuncOp, additional information about its bbArgs is
19
21
// gathered and stored in `FuncAnalysisState`.
@@ -309,34 +311,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
309
311
// / Return `failure()` if we are unable to retrieve the called FuncOp from
310
312
// / any func::CallOp.
311
313
static LogicalResult getFuncOpsOrderedByCalls (
312
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
314
+ Operation * moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313
315
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314
316
SymbolTableCollection &symbolTables) {
315
317
// For each FuncOp, the set of functions called by it (i.e. the union of
316
318
// symbols of all nested func::CallOp).
317
319
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
318
320
// For each FuncOp, the number of func::CallOp it contains.
319
321
DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
320
-
321
- for (func::FuncOp funcOp : moduleOp.getOps <func::FuncOp>()) {
322
- // Collect function calls and populate the caller map.
323
- numberCallOpsContainedInFuncOp[funcOp] = 0 ;
324
- WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
325
- func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables);
326
- assert (calledFunction && " could not retrieved called func::FuncOp" );
327
- // If the called function does not have any tensors in its signature, then
328
- // it is not necessary to bufferize the callee before the caller.
329
- if (!hasTensorSignature (calledFunction))
330
- return WalkResult::skip ();
331
-
332
- callerMap[calledFunction].insert (callOp);
333
- if (calledBy[calledFunction].insert (funcOp).second ) {
334
- numberCallOpsContainedInFuncOp[funcOp]++;
322
+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
323
+ for (mlir::Block &block : region.getBlocks ()) {
324
+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
325
+ // Collect function calls and populate the caller map.
326
+ numberCallOpsContainedInFuncOp[funcOp] = 0 ;
327
+ WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
328
+ func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables);
329
+ assert (calledFunction && " could not retrieved called func::FuncOp" );
330
+ // If the called function does not have any tensors in its signature,
331
+ // then it is not necessary to bufferize the callee before the caller.
332
+ if (!hasTensorSignature (calledFunction))
333
+ return WalkResult::skip ();
334
+
335
+ callerMap[calledFunction].insert (callOp);
336
+ if (calledBy[calledFunction].insert (funcOp).second ) {
337
+ numberCallOpsContainedInFuncOp[funcOp]++;
338
+ }
339
+ return WalkResult::advance ();
340
+ });
341
+ if (res.wasInterrupted ())
342
+ return failure ();
335
343
}
336
- return WalkResult::advance ();
337
- });
338
- if (res.wasInterrupted ())
339
- return failure ();
344
+ }
340
345
}
341
346
342
347
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
447
452
}
448
453
449
454
LogicalResult
450
- mlir::bufferization::analyzeModuleOp (ModuleOp moduleOp,
455
+ mlir::bufferization::analyzeModuleOp (Operation * moduleOp,
451
456
OneShotAnalysisState &state,
452
457
BufferizationStatistics *statistics) {
453
458
assert (state.getOptions ().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
512
517
}
513
518
514
519
void mlir::bufferization::removeBufferizationAttributesInModule (
515
- ModuleOp moduleOp) {
516
- for (auto op : moduleOp.getOps <func::FuncOp>()) {
517
- for (BlockArgument bbArg : op.getArguments ())
518
- removeBufferizationAttributes (bbArg);
520
+ Operation *moduleOp) {
521
+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
522
+ for (mlir::Block &block : region.getBlocks ()) {
523
+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
524
+ for (BlockArgument bbArg : funcOp.getArguments ())
525
+ removeBufferizationAttributes (bbArg);
526
+ }
527
+ }
519
528
}
520
529
}
521
530
522
531
LogicalResult mlir::bufferization::bufferizeModuleOp (
523
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
532
+ Operation * moduleOp, const OneShotBufferizationOptions &options,
524
533
BufferizationState &state, BufferizationStatistics *statistics) {
525
534
assert (options.bufferizeFunctionBoundaries &&
526
535
" expected that function boundary bufferization is activated" );
527
- IRRewriter rewriter (moduleOp. getContext ());
536
+ IRRewriter rewriter (moduleOp-> getContext ());
528
537
529
538
// A list of non-circular functions in the order in which they are analyzed
530
539
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571
580
}
572
581
573
582
// Bufferize all other ops.
574
- for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
575
- // Functions were already bufferized.
576
- if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
577
- continue ;
578
- if (failed (bufferizeOp (&op, options, state, statistics)))
579
- return failure ();
583
+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
584
+ for (mlir::Block &block : region.getBlocks ()) {
585
+ for (mlir::Operation &op :
586
+ llvm::make_early_inc_range (block.getOperations ())) {
587
+ // Functions were already bufferized.
588
+ if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
589
+ continue ;
590
+ if (failed (bufferizeOp (&op, options, state, statistics)))
591
+ return failure ();
592
+ }
593
+ }
580
594
}
581
595
582
596
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
586
600
}
587
601
588
602
LogicalResult mlir::bufferization::runOneShotModuleBufferize (
589
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
603
+ Operation * moduleOp, const OneShotBufferizationOptions &options,
590
604
BufferizationState &state, BufferizationStatistics *statistics) {
591
605
assert (options.bufferizeFunctionBoundaries &&
592
606
" expected that function boundary bufferization is activated" );
0 commit comments