Skip to content

[mlir][Transforms] Add ApplyConversionAction for profiling purposes #146208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Region;
class Block;
class Value;

/// IRUnit is a union of the different types of IR objects that consistute the
/// IRUnit is a union of the different types of IR objects that constitute the
/// IR structure (other than Type and Attribute), that is Operation, Region, and
/// Block.
class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {
Expand Down
49 changes: 38 additions & 11 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2713,8 +2713,7 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}

LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();

// Compute the set of operations and blocks to convert.
Expand Down Expand Up @@ -3417,16 +3416,47 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//

/// This is the type of Action that is dispatched when a conversion is applied.
class ApplyConversionAction
: public tracing::ActionImpl<ApplyConversionAction> {
public:
using Base = tracing::ActionImpl<ApplyConversionAction>;
ApplyConversionAction(ArrayRef<IRUnit> irUnits) : Base(irUnits) {}
static constexpr StringLiteral tag = "apply-conversion";
static constexpr StringLiteral desc =
"Encapsulate the application of a dialect conversion";

void print(raw_ostream &os) const override { os << tag; }
};

static LogicalResult applyConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config,
OpConversionMode mode) {
if (ops.empty())
return success();
MLIRContext *ctx = ops.front()->getContext();
LogicalResult status = success();
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
ctx->executeAction<ApplyConversionAction>(
[&] {
OperationConverter opConverter(target, patterns, config, mode);
status = opConverter.convertOperations(ops);
},
irUnits);
return status;
}

//===----------------------------------------------------------------------===//
// Partial Conversion
//===----------------------------------------------------------------------===//

LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Partial);
return opConverter.convertOperations(ops);
return applyConversion(ops, target, patterns, config,
OpConversionMode::Partial);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
Expand All @@ -3443,9 +3473,7 @@ LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) {
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Full);
return opConverter.convertOperations(ops);
return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
}
LogicalResult mlir::applyFullConversion(Operation *op,
const ConversionTarget &target,
Expand Down Expand Up @@ -3512,9 +3540,8 @@ LogicalResult mlir::applyAnalysisConversion(
// Convert the cloned operations. The original IR will remain unchanged.
SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
ops, [&](Operation *op) { return mapping.lookup(op); });
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Analysis);
LogicalResult status = opConverter.convertOperations(opsToConvert);
LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
OpConversionMode::Analysis);

// Remap `legalizableOps`, so that they point to the original ops and not the
// cloned ops.
Expand Down
Loading