From 21a81c7109da2295caf1e2064acf3b6ef3002660 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 28 Jun 2025 10:27:27 +0000 Subject: [PATCH] [mlir][Transforms] Add `ApplyConversionAction` for profiling purposes --- mlir/include/mlir/IR/Unit.h | 2 +- .../Transforms/Utils/DialectConversion.cpp | 49 ++++++++++++++----- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h index 63117a7664a7d..0e99f8e25f326 100644 --- a/mlir/include/mlir/IR/Unit.h +++ b/mlir/include/mlir/IR/Unit.h @@ -22,7 +22,7 @@ class Region; class Block; class Value; -/// IRUnit is a union of the different types of IR objects that consistute the +/// IRUnit is a union of the different types of IR objects that constitute the /// IR structure (other than Type and Attribute), that is Operation, Region, and /// Block. class IRUnit : public PointerUnion { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 955a106c21941..433f299c1c1e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2713,8 +2713,7 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, } LogicalResult OperationConverter::convertOperations(ArrayRef ops) { - if (ops.empty()) - return success(); + assert(!ops.empty() && "expected at least one operation"); const ConversionTarget &target = opLegalizer.getTarget(); // Compute the set of operations and blocks to convert. @@ -3417,6 +3416,38 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { // Op Conversion Entry Points //===----------------------------------------------------------------------===// +/// This is the type of Action that is dispatched when a conversion is applied. +class ApplyConversionAction + : public tracing::ActionImpl { +public: + using Base = tracing::ActionImpl; + ApplyConversionAction(ArrayRef irUnits) : Base(irUnits) {} + static constexpr StringLiteral tag = "apply-conversion"; + static constexpr StringLiteral desc = + "Encapsulate the application of a dialect conversion"; + + void print(raw_ostream &os) const override { os << tag; } +}; + +static LogicalResult applyConversion(ArrayRef ops, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config, + OpConversionMode mode) { + if (ops.empty()) + return success(); + MLIRContext *ctx = ops.front()->getContext(); + LogicalResult status = success(); + SmallVector irUnits(ops.begin(), ops.end()); + ctx->executeAction( + [&] { + OperationConverter opConverter(target, patterns, config, mode); + status = opConverter.convertOperations(ops); + }, + irUnits); + return status; +} + //===----------------------------------------------------------------------===// // Partial Conversion //===----------------------------------------------------------------------===// @@ -3424,9 +3455,8 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { LogicalResult mlir::applyPartialConversion( ArrayRef ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config) { - OperationConverter opConverter(target, patterns, config, - OpConversionMode::Partial); - return opConverter.convertOperations(ops); + return applyConversion(ops, target, patterns, config, + OpConversionMode::Partial); } LogicalResult mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, @@ -3443,9 +3473,7 @@ LogicalResult mlir::applyFullConversion(ArrayRef ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config) { - OperationConverter opConverter(target, patterns, config, - OpConversionMode::Full); - return opConverter.convertOperations(ops); + return applyConversion(ops, target, patterns, config, OpConversionMode::Full); } LogicalResult mlir::applyFullConversion(Operation *op, const ConversionTarget &target, @@ -3512,9 +3540,8 @@ LogicalResult mlir::applyAnalysisConversion( // Convert the cloned operations. The original IR will remain unchanged. SmallVector opsToConvert = llvm::map_to_vector( ops, [&](Operation *op) { return mapping.lookup(op); }); - OperationConverter opConverter(target, patterns, config, - OpConversionMode::Analysis); - LogicalResult status = opConverter.convertOperations(opsToConvert); + LogicalResult status = applyConversion(opsToConvert, target, patterns, config, + OpConversionMode::Analysis); // Remap `legalizableOps`, so that they point to the original ops and not the // cloned ops.