-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms] Add ApplyConversionAction
for profiling purposes
#146208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Transforms] Add ApplyConversionAction
for profiling purposes
#146208
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd a new Full diff: https://github.com/llvm/llvm-project/pull/146208.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
index 63117a7664a7d..0e99f8e25f326 100644
--- a/mlir/include/mlir/IR/Unit.h
+++ b/mlir/include/mlir/IR/Unit.h
@@ -22,7 +22,7 @@ class Region;
class Block;
class Value;
-/// IRUnit is a union of the different types of IR objects that consistute the
+/// IRUnit is a union of the different types of IR objects that constitute the
/// IR structure (other than Type and Attribute), that is Operation, Region, and
/// Block.
class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 955a106c21941..433f299c1c1e5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2713,8 +2713,7 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
- if (ops.empty())
- return success();
+ assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
@@ -3417,6 +3416,38 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
+/// This is the type of Action that is dispatched when a conversion is applied.
+class ApplyConversionAction
+ : public tracing::ActionImpl<ApplyConversionAction> {
+public:
+ using Base = tracing::ActionImpl<ApplyConversionAction>;
+ ApplyConversionAction(ArrayRef<IRUnit> irUnits) : Base(irUnits) {}
+ static constexpr StringLiteral tag = "apply-conversion";
+ static constexpr StringLiteral desc =
+ "Encapsulate the application of a dialect conversion";
+
+ void print(raw_ostream &os) const override { os << tag; }
+};
+
+static LogicalResult applyConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config,
+ OpConversionMode mode) {
+ if (ops.empty())
+ return success();
+ MLIRContext *ctx = ops.front()->getContext();
+ LogicalResult status = success();
+ SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
+ ctx->executeAction<ApplyConversionAction>(
+ [&] {
+ OperationConverter opConverter(target, patterns, config, mode);
+ status = opConverter.convertOperations(ops);
+ },
+ irUnits);
+ return status;
+}
+
//===----------------------------------------------------------------------===//
// Partial Conversion
//===----------------------------------------------------------------------===//
@@ -3424,9 +3455,8 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Partial);
- return opConverter.convertOperations(ops);
+ return applyConversion(ops, target, patterns, config,
+ OpConversionMode::Partial);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
@@ -3443,9 +3473,7 @@ LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) {
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Full);
- return opConverter.convertOperations(ops);
+ return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
}
LogicalResult mlir::applyFullConversion(Operation *op,
const ConversionTarget &target,
@@ -3512,9 +3540,8 @@ LogicalResult mlir::applyAnalysisConversion(
// Convert the cloned operations. The original IR will remain unchanged.
SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
ops, [&](Operation *op) { return mapping.lookup(op); });
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Analysis);
- LogicalResult status = opConverter.convertOperations(opsToConvert);
+ LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
+ OpConversionMode::Analysis);
// Remap `legalizableOps`, so that they point to the original ops and not the
// cloned ops.
|
Would it make sense to add tests for this? |
Add a new
ApplyConversionAction
so that users can profile the time that is spent in the conversion driver.