diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 10cfe851765dc..d6045ab46d21d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -475,6 +475,25 @@ class RewriterBase : public OpBuilder { RewriterBase::Listener *rewriteListener; }; + /// A listener that logs notification events to llvm::dbgs() before + /// forwarding to the base listener. + struct PatternLoggingListener : public RewriterBase::ForwardingListener { + PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName) + : RewriterBase::ForwardingListener(listener), patternName(patternName) { + } + + void notifyOperationInserted(Operation *op, InsertPoint previous) override; + void notifyOperationModified(Operation *op) override; + void notifyOperationReplaced(Operation *op, Operation *newOp) override; + void notifyOperationReplaced(Operation *op, + ValueRange replacement) override; + void notifyOperationErased(Operation *op) override; + void notifyPatternBegin(const Pattern &pattern, Operation *op) override; + + private: + StringRef patternName; + }; + /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller /// is responsible for creating or updating the operation transferring flow diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 4cabac185171c..3ef69cea18f0a 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -29,6 +29,7 @@ add_mlir_library(MLIRIR ODSSupport.cpp Operation.cpp OperationSupport.cpp + PatternLoggingListener.cpp PatternMatch.cpp Region.cpp RegionKindInterface.cpp diff --git a/mlir/lib/IR/PatternLoggingListener.cpp b/mlir/lib/IR/PatternLoggingListener.cpp new file mode 100644 index 0000000000000..735997bbd5937 --- /dev/null +++ b/mlir/lib/IR/PatternLoggingListener.cpp @@ -0,0 +1,55 @@ +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "pattern-logging-listener" + +using namespace mlir; + +static constexpr StringLiteral catalogPrefix = "PatternLoggingListener: "; + +void RewriterBase::PatternLoggingListener::notifyOperationInserted( + Operation *op, InsertPoint previous) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyOperationInserted" + << " | " << op->getName() << "\n"); + ForwardingListener::notifyOperationInserted(op, previous); +} + +void RewriterBase::PatternLoggingListener::notifyOperationModified(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyOperationModified" + << " | " << op->getName() << "\n"); + ForwardingListener::notifyOperationModified(op); +} + +void RewriterBase::PatternLoggingListener::notifyOperationReplaced( + Operation *op, Operation *newOp) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyOperationReplaced (with op)" + << " | " << op->getName() << " | " << newOp->getName() + << "\n"); + ForwardingListener::notifyOperationReplaced(op, newOp); +} + +void RewriterBase::PatternLoggingListener::notifyOperationReplaced( + Operation *op, ValueRange replacement) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyOperationReplaced (with values)" + << " | " << op->getName() << "\n"); + ForwardingListener::notifyOperationReplaced(op, replacement); +} + +void RewriterBase::PatternLoggingListener::notifyOperationErased(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyOperationErased" + << " | " << op->getName() << "\n"); + ForwardingListener::notifyOperationErased(op); +} + +void RewriterBase::PatternLoggingListener::notifyPatternBegin( + const Pattern &pattern, Operation *op) { + LLVM_DEBUG(llvm::dbgs() << catalogPrefix << patternName + << " | notifyPatternBegin" + << " | " << op->getName() << "\n"); + ForwardingListener::notifyPatternBegin(pattern, op); +} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index 4a12183492fd4..b2b372b7b1249 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -15,6 +15,10 @@ #include "ByteCode.h" #include "llvm/Support/Debug.h" +#ifndef NDEBUG +#include "llvm/ADT/ScopeExit.h" +#endif + #define DEBUG_TYPE "pattern-application" using namespace mlir; @@ -206,11 +210,19 @@ LogicalResult PatternApplicator::matchAndRewrite( } else { LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" << bestPattern->getDebugName() << "\"\n"); - const auto *pattern = static_cast(bestPattern); - result = pattern->matchAndRewrite(op, rewriter); +#ifndef NDEBUG + OpBuilder::Listener *oldListener = rewriter.getListener(); + auto loggingListener = + std::make_unique( + oldListener, pattern->getDebugName()); + rewriter.setListener(loggingListener.get()); + auto resetListenerCallback = llvm::make_scope_exit( + [&] { rewriter.setListener(oldListener); }); +#endif + result = pattern->matchAndRewrite(op, rewriter); LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName() << "\" result " << succeeded(result) << "\n"); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 9b5cadd62befc..233fef8ec4296 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -301,6 +301,17 @@ def find_real_python_interpreter(): ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"), ] ) +elif "MLIR_GENERATE_PATTERN_CATALOG" in os.environ: + tools.extend( + [ + ToolSubst( + "mlir-opt", + "mlir-opt --debug-only=pattern-logging-listener --mlir-disable-threading", + unresolved="fatal", + ), + ToolSubst("FileCheck", "FileCheck --dump-input=always", unresolved="fatal"), + ] + ) else: tools.extend(["mlir-opt"])