Skip to content

Commit e105cee

Browse files
committed
[mlir] Generalize OneShotModuleBufferize to operate on any Operation
1 parent 3fa07ed commit e105cee

File tree

9 files changed

+172
-56
lines changed

9 files changed

+172
-56
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct LogicalResult;
1414
} // namespace llvm
1515

1616
namespace mlir {
17-
class ModuleOp;
17+
class Operation;
1818

1919
namespace bufferization {
2020
struct BufferizationStatistics;
@@ -23,12 +23,13 @@ struct OneShotBufferizationOptions;
2323
class BufferizationState;
2424

2525
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
26-
/// `state`.
26+
/// `state`. This operates on any `SymbolTable` op.
2727
llvm::LogicalResult
28-
analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
28+
analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state,
2929
BufferizationStatistics *statistics = nullptr);
3030

31-
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
31+
/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`.
32+
/// This operates on any `SymbolTable` op.
3233
///
3334
/// Note: This function does not run One-Shot Analysis. No buffer copies are
3435
/// inserted except two cases:
@@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3738
/// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter`
3839
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
3940
/// will be inserted only to these FuncOps.
40-
llvm::LogicalResult
41-
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42-
BufferizationState &state,
43-
BufferizationStatistics *statistics = nullptr);
41+
llvm::LogicalResult bufferizeModuleOp(
42+
Operation *moduleOp, const OneShotBufferizationOptions &options,
43+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
4444

45-
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
46-
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
45+
/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable
46+
/// op.
47+
void removeBufferizationAttributesInModule(Operation *moduleOp);
4748

48-
/// Run One-Shot Module Bufferization on the given module. Performs a simple
49-
/// function call analysis to determine which function arguments are
49+
/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a
50+
/// simple function call analysis to determine which function arguments are
5051
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
5152
/// Bufferize.
5253
llvm::LogicalResult runOneShotModuleBufferize(
53-
ModuleOp moduleOp,
54+
Operation *moduleOp,
5455
const bufferization::OneShotBufferizationOptions &options,
5556
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5657

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
1+
//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
2+
//----===//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
89
//
910
// Module Bufferization is an extension of One-Shot Bufferize that
1011
// bufferizes function boundaries. It provides `BufferizableOpInterface`
11-
// implementations for FuncOp, CallOp and ReturnOp.
12+
// implementations for FuncOp, CallOp and ReturnOp. Although it is named
13+
// Module Bufferization, it may operate on any SymbolTable.
1214
//
13-
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14-
// This function analyzes the given module and determines the order of analysis
15-
// and bufferization: Functions that are called are processed before their
16-
// respective callers.
15+
// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
16+
// ...)`. This function analyzes the given op and determines the order of
17+
// analysis and bufferization: Functions that are called are processed before
18+
// their respective callers.
1719
//
1820
// After analyzing a FuncOp, additional information about its bbArgs is
1921
// gathered and stored in `FuncAnalysisState`.
@@ -309,34 +311,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
309311
/// Return `failure()` if we are unable to retrieve the called FuncOp from
310312
/// any func::CallOp.
311313
static LogicalResult getFuncOpsOrderedByCalls(
312-
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
314+
Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313315
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314316
SymbolTableCollection &symbolTables) {
315317
// For each FuncOp, the set of functions called by it (i.e. the union of
316318
// symbols of all nested func::CallOp).
317319
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
318320
// For each FuncOp, the number of func::CallOp it contains.
319321
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
320-
321-
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
322-
// Collect function calls and populate the caller map.
323-
numberCallOpsContainedInFuncOp[funcOp] = 0;
324-
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
325-
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
326-
assert(calledFunction && "could not retrieved called func::FuncOp");
327-
// If the called function does not have any tensors in its signature, then
328-
// it is not necessary to bufferize the callee before the caller.
329-
if (!hasTensorSignature(calledFunction))
330-
return WalkResult::skip();
331-
332-
callerMap[calledFunction].insert(callOp);
333-
if (calledBy[calledFunction].insert(funcOp).second) {
334-
numberCallOpsContainedInFuncOp[funcOp]++;
322+
for (mlir::Region &region : moduleOp->getRegions()) {
323+
for (mlir::Block &block : region.getBlocks()) {
324+
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
325+
// Collect function calls and populate the caller map.
326+
numberCallOpsContainedInFuncOp[funcOp] = 0;
327+
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
328+
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
329+
assert(calledFunction && "could not retrieved called func::FuncOp");
330+
// If the called function does not have any tensors in its signature,
331+
// then it is not necessary to bufferize the callee before the caller.
332+
if (!hasTensorSignature(calledFunction))
333+
return WalkResult::skip();
334+
335+
callerMap[calledFunction].insert(callOp);
336+
if (calledBy[calledFunction].insert(funcOp).second) {
337+
numberCallOpsContainedInFuncOp[funcOp]++;
338+
}
339+
return WalkResult::advance();
340+
});
341+
if (res.wasInterrupted())
342+
return failure();
335343
}
336-
return WalkResult::advance();
337-
});
338-
if (res.wasInterrupted())
339-
return failure();
344+
}
340345
}
341346

342347
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
447452
}
448453

449454
LogicalResult
450-
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
455+
mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
451456
OneShotAnalysisState &state,
452457
BufferizationStatistics *statistics) {
453458
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
512517
}
513518

514519
void mlir::bufferization::removeBufferizationAttributesInModule(
515-
ModuleOp moduleOp) {
516-
for (auto op : moduleOp.getOps<func::FuncOp>()) {
517-
for (BlockArgument bbArg : op.getArguments())
518-
removeBufferizationAttributes(bbArg);
520+
Operation *moduleOp) {
521+
for (mlir::Region &region : moduleOp->getRegions()) {
522+
for (mlir::Block &block : region.getBlocks()) {
523+
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
524+
for (BlockArgument bbArg : funcOp.getArguments())
525+
removeBufferizationAttributes(bbArg);
526+
}
527+
}
519528
}
520529
}
521530

522531
LogicalResult mlir::bufferization::bufferizeModuleOp(
523-
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
532+
Operation *moduleOp, const OneShotBufferizationOptions &options,
524533
BufferizationState &state, BufferizationStatistics *statistics) {
525534
assert(options.bufferizeFunctionBoundaries &&
526535
"expected that function boundary bufferization is activated");
527-
IRRewriter rewriter(moduleOp.getContext());
536+
IRRewriter rewriter(moduleOp->getContext());
528537

529538
// A list of non-circular functions in the order in which they are analyzed
530539
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571580
}
572581

573582
// Bufferize all other ops.
574-
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
575-
// Functions were already bufferized.
576-
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
577-
continue;
578-
if (failed(bufferizeOp(&op, options, state, statistics)))
579-
return failure();
583+
for (mlir::Region &region : moduleOp->getRegions()) {
584+
for (mlir::Block &block : region.getBlocks()) {
585+
for (mlir::Operation &op :
586+
llvm::make_early_inc_range(block.getOperations())) {
587+
// Functions were already bufferized.
588+
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
589+
continue;
590+
if (failed(bufferizeOp(&op, options, state, statistics)))
591+
return failure();
592+
}
593+
}
580594
}
581595

582596
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
586600
}
587601

588602
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
589-
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
603+
Operation *moduleOp, const OneShotBufferizationOptions &options,
590604
BufferizationState &state, BufferizationStatistics *statistics) {
591605
assert(options.bufferizeFunctionBoundaries &&
592606
"expected that function boundary bufferization is activated");

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
3535
// analysis depending on whether function boundary bufferization is enabled or
3636
// not.
3737
if (options.bufferizeFunctionBoundaries) {
38-
if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
38+
if (failed(analyzeModuleOp(op, analysisState, statistics)))
3939
return failure();
4040
} else {
4141
if (failed(analyzeOp(op, analysisState, statistics)))

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ class SparsificationAndBufferizationPass
115115

116116
bufferization::BufferizationState bufferizationState;
117117

118-
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
119-
updatedOptions,
118+
if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
120119
bufferizationState)))
121120
return failure();
122121

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s
2+
3+
"test.symbol_scope_isolated"() ({
4+
// CHECK-LABEL: func @inner_func(
5+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
6+
func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
7+
// CHECK-NOT: copy
8+
%f = arith.constant 1.0 : f32
9+
%c0 = arith.constant 0 : index
10+
%c1 = arith.constant 1 : index
11+
// CHECK: memref.store %{{.*}}, %[[arg0]]
12+
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
13+
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
14+
%1 = tensor.extract %0[%c1] : tensor<?xf32>
15+
// CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
16+
return %0, %1 : tensor<?xf32>, f32
17+
}
18+
19+
// CHECK-LABEL: func @call_func_with_non_tensor_return(
20+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
21+
func.func @call_func_with_non_tensor_return(
22+
%t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
23+
// CHECK-NOT: alloc
24+
// CHECK-NOT: copy
25+
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
26+
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
27+
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
28+
return %1, %0 : f32, tensor<?xf32>
29+
}
30+
"test.finish" () : () -> ()
31+
}) : () -> ()
32+
33+

mlir/test/lib/Dialect/Bufferization/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRBufferizationTestPasses
3+
TestOneShotModuleBufferize.cpp
34
TestTensorCopyInsertion.cpp
45
TestTensorLikeAndBufferLike.cpp
56

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11+
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
using namespace mlir;
17+
18+
namespace {
19+
struct TestOneShotModuleBufferizePass
20+
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
21+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
22+
23+
TestOneShotModuleBufferizePass() = default;
24+
TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass)
25+
: PassWrapper(pass) {}
26+
27+
void getDependentDialects(DialectRegistry &registry) const override {
28+
registry.insert<bufferization::BufferizationDialect>();
29+
}
30+
StringRef getArgument() const final {
31+
return "test-one-shot-module-bufferize";
32+
}
33+
StringRef getDescription() const final {
34+
return "Pass to test One Shot Module Bufferization";
35+
}
36+
37+
void runOnOperation() override {
38+
39+
llvm::errs() << "Running TestOneShotModuleBufferize on: "
40+
<< getOperation()->getName() << "\n";
41+
bufferization::OneShotBufferizationOptions opt;
42+
43+
opt.bufferizeFunctionBoundaries = true;
44+
bufferization::BufferizationState bufferizationState;
45+
46+
if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
47+
bufferizationState)))
48+
signalPassFailure();
49+
}
50+
};
51+
} // namespace
52+
53+
namespace mlir::test {
54+
void registerTestOneShotModuleBufferizePass() {
55+
PassRegistration<TestOneShotModuleBufferizePass>();
56+
}
57+
} // namespace mlir::test

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
125125
let regions = (region SizedRegion<1>:$region);
126126
}
127127

128+
def SymbolScopeIsolatedOp
129+
: TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable,
130+
SingleBlockImplicitTerminator<
131+
"TerminatorOp">]> {
132+
let summary =
133+
"operation which defines a new symbol table that is IsolatedFromAbove";
134+
let regions = (region SizedRegion<1>:$region);
135+
}
136+
128137
def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
129138
let summary = "operation which defines a new symbol table without a "
130139
"restriction on a terminator";

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ void registerTestMeshSimplificationsPass();
135135
void registerTestMultiBuffering();
136136
void registerTestNextAccessPass();
137137
void registerTestNVGPULowerings();
138+
void registerTestOneShotModuleBufferizePass();
138139
void registerTestOpaqueLoc();
139140
void registerTestOpLoweringPasses();
140141
void registerTestPadFusion();
@@ -281,6 +282,7 @@ void registerTestPasses() {
281282
mlir::test::registerTestMultiBuffering();
282283
mlir::test::registerTestNextAccessPass();
283284
mlir::test::registerTestNVGPULowerings();
285+
mlir::test::registerTestOneShotModuleBufferizePass();
284286
mlir::test::registerTestOpaqueLoc();
285287
mlir::test::registerTestOpLoweringPasses();
286288
mlir::test::registerTestPadFusion();

0 commit comments

Comments
 (0)