Skip to content

Commit ad20dc0

Browse files
[mlir][Transforms] Add ApplyConversionAction for profiling purposes (#146208)
Add a new `ApplyConversionAction` so that users can profile the time that is spent in the conversion driver.
1 parent b9b2661 commit ad20dc0

File tree

3 files changed

+50
-15
lines changed

3 files changed

+50
-15
lines changed

mlir/include/mlir/IR/Unit.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Region;
2222
class Block;
2323
class Value;
2424

25-
/// IRUnit is a union of the different types of IR objects that consistute the
25+
/// IRUnit is a union of the different types of IR objects that constitute the
2626
/// IR structure (other than Type and Attribute), that is Operation, Region, and
2727
/// Block.
2828
class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,8 +2711,7 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
27112711
}
27122712

27132713
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2714-
if (ops.empty())
2715-
return success();
2714+
assert(!ops.empty() && "expected at least one operation");
27162715
const ConversionTarget &target = opLegalizer.getTarget();
27172716

27182717
// Compute the set of operations and blocks to convert.
@@ -3415,16 +3414,47 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34153414
// Op Conversion Entry Points
34163415
//===----------------------------------------------------------------------===//
34173416

3417+
/// This is the type of Action that is dispatched when a conversion is applied.
3418+
class ApplyConversionAction
3419+
: public tracing::ActionImpl<ApplyConversionAction> {
3420+
public:
3421+
using Base = tracing::ActionImpl<ApplyConversionAction>;
3422+
ApplyConversionAction(ArrayRef<IRUnit> irUnits) : Base(irUnits) {}
3423+
static constexpr StringLiteral tag = "apply-conversion";
3424+
static constexpr StringLiteral desc =
3425+
"Encapsulate the application of a dialect conversion";
3426+
3427+
void print(raw_ostream &os) const override { os << tag; }
3428+
};
3429+
3430+
static LogicalResult applyConversion(ArrayRef<Operation *> ops,
3431+
const ConversionTarget &target,
3432+
const FrozenRewritePatternSet &patterns,
3433+
ConversionConfig config,
3434+
OpConversionMode mode) {
3435+
if (ops.empty())
3436+
return success();
3437+
MLIRContext *ctx = ops.front()->getContext();
3438+
LogicalResult status = success();
3439+
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
3440+
ctx->executeAction<ApplyConversionAction>(
3441+
[&] {
3442+
OperationConverter opConverter(target, patterns, config, mode);
3443+
status = opConverter.convertOperations(ops);
3444+
},
3445+
irUnits);
3446+
return status;
3447+
}
3448+
34183449
//===----------------------------------------------------------------------===//
34193450
// Partial Conversion
34203451
//===----------------------------------------------------------------------===//
34213452

34223453
LogicalResult mlir::applyPartialConversion(
34233454
ArrayRef<Operation *> ops, const ConversionTarget &target,
34243455
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3425-
OperationConverter opConverter(target, patterns, config,
3426-
OpConversionMode::Partial);
3427-
return opConverter.convertOperations(ops);
3456+
return applyConversion(ops, target, patterns, config,
3457+
OpConversionMode::Partial);
34283458
}
34293459
LogicalResult
34303460
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
@@ -3441,9 +3471,7 @@ LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
34413471
const ConversionTarget &target,
34423472
const FrozenRewritePatternSet &patterns,
34433473
ConversionConfig config) {
3444-
OperationConverter opConverter(target, patterns, config,
3445-
OpConversionMode::Full);
3446-
return opConverter.convertOperations(ops);
3474+
return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
34473475
}
34483476
LogicalResult mlir::applyFullConversion(Operation *op,
34493477
const ConversionTarget &target,
@@ -3510,9 +3538,8 @@ LogicalResult mlir::applyAnalysisConversion(
35103538
// Convert the cloned operations. The original IR will remain unchanged.
35113539
SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
35123540
ops, [&](Operation *op) { return mapping.lookup(op); });
3513-
OperationConverter opConverter(target, patterns, config,
3514-
OpConversionMode::Analysis);
3515-
LogicalResult status = opConverter.convertOperations(opsToConvert);
3541+
LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
3542+
OpConversionMode::Analysis);
35163543

35173544
// Remap `legalizableOps`, so that they point to the original ops and not the
35183545
// cloned ops.

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
2-
1+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s
2+
3+
// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "B"
4+
// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "B"
5+
// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "B"
6+
// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "E"
7+
// Note: Listener notifications appear after the pattern application because
8+
// the conversion driver sends all notifications at the end of the conversion
9+
// in bulk.
310
// CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
411
// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
512
// CHECK-NEXT: notifyOperationModified: func.return
613
// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
7-
14+
// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "E"
15+
// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "E"
816
// CHECK-LABEL: verifyDirectPattern
917
func.func @verifyDirectPattern() -> i32 {
1018
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}

0 commit comments

Comments
 (0)