diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 763146aac15b9..431608c1f71c8 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -60,10 +60,29 @@ class GreedyRewriteConfig { static constexpr int64_t kNoLimit = -1; - /// Only ops within the scope are added to the worklist. If no scope is - /// specified, the closest enclosing region around the initial list of ops - /// (or the specified region, depending on which greedy rewrite entry point - /// is used) is used as a scope. + /// Only ops within the scope are allowed to be modified and are added to the + /// worklist. + /// + /// If out-of-scope IR is modified, an assertion will fail inside the greedy + /// pattern rewrite driver if expensive checks are enabled (as long as rewrite + /// patterns use the rewriter API correctly). We also allow attribute + /// modifications of the op that owns the scope region. (This is consistent + /// with the fact that passes are allowed to modify attributes of the + /// operation that they operate on.) + /// + /// The scope region must be isolated from above. This ensures that + /// out-of-scope ops are not affected by rewrites. + /// + /// If no scope is specified, it is set as follows: + /// * Single op greedy rewrite: a greedy rewrite is performed for every region + /// of the op. (See below.) The scope is set to the respective region of + /// each greedy write. + /// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region + /// around the initial list of ops. If there is no such region, the scope + /// is `nullptr`. This is because multi-op greedy rewrites are allowed to + /// modify top-level ops. (They are not allowed to erase top-level ops.) + /// * Single region greedy rewrite: the specified region. (The op that owns + /// the region must be isolated from above.) Region *scope = nullptr; /// Strict mode can restrict the ops that are added to the worklist during @@ -124,11 +143,9 @@ applyPatternsAndFoldGreedily(Region ®ion, /// This overload runs a separate greedy rewrite for each region of the /// specified op. A region scope can be set in the configuration parameter. By /// default, the scope is set to the region of the current greedy rewrite. Only -/// in-scope ops are added to the worklist and only in-scope ops and the -/// specified op itself are allowed to be modified by the patterns. -/// -/// Note: The specified op may be modified, but it may not be removed by the -/// patterns. +/// in-scope ops are added to the worklist and only in-scope ops are allowed to +/// be modified by the patterns. In addition, the attributes of the op that +/// owns the scope region may also be modified. /// /// Returns "success" if the iterative process converged (i.e., fixpoint was /// reached) and no more patterns can be matched within the region. `changed` diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 67c2d9d59f4c9..c9a49094fac3d 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -324,6 +324,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter, llvm::SmallDenseSet strictModeFilteredOps; private: +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + /// Return "true" if the given op is guaranteed to be out of scope. + bool isOutOfScope(Operation *op) const; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + /// Look over the provided operands for any defining operations that should /// be re-added to the worklist. This function should be called when an /// operation is modified or removed, as it may trigger further @@ -375,6 +380,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const { + // No op is out of scope if no scope was set. + if (!config.scope) + return false; + // Check if the given op and the scope region are part of the same IR tree. + // The parent op into which the given op was inserted may be unlinked, in + // which case we do not consider the given op to be out of scope. (That parent + // op will likely be inserted later, together with all its nested ops.) + Region *r = config.scope; + while (r) { + if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op) + break; + r = r->getParentRegion(); + } + if (!r) + return false; + // Op is out of scope if it is not within the scope region. + return !config.scope->findAncestorOpInRegion(*op); +} +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + bool GreedyPatternRewriteDriver::processWorklist() { #ifndef NDEBUG const char *logLineComment = @@ -579,6 +606,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { addSingleOpToWorklist(op); return; } + // TODO: Unlinked ops are currently not added to the worklist if a `scope` + // is specified. if (region == nullptr) return; } while ((op = region->getParentOp())); @@ -600,6 +629,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope && isOutOfScope(op)) + llvm::report_fatal_error( + "greedy pattern rewrite inserted op into region that is out of scope"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.listener) config.listener->notifyOperationInserted(op); if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) @@ -608,10 +644,24 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { } void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) { + // TODO: This notification should also be triggered when moving an op into + // this op. LLVM_DEBUG({ logger.startLine() << "** Modified: '" << op->getName() << "'(" << op << ")\n"; }); + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope) { + // Modifying attributes of the op that owns the scope region is allowed + // when using the applyPatternsAndFoldGreedily(Operation *) entry point. + if (op != config.scope->getParentOp() && isOutOfScope(op)) { + llvm::report_fatal_error("greedy pattern rewrite modified op within " + "region that is out of scope"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.listener) config.listener->notifyOperationModified(op); addToWorklist(op); @@ -637,16 +687,11 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { << ")\n"; }); -#ifndef NDEBUG - // Only ops that are within the configured scope are added to the worklist of - // the greedy pattern rewriter. Moreover, the parent op of the scope region is - // the part of the IR that is taken into account for the "expensive checks". - // A greedy pattern rewrite is not allowed to erase the parent op of the scope - // region, as that would break the worklist handling and the expensive checks. - if (config.scope && config.scope->getParentOp() == op) - llvm_unreachable( - "scope region must not be erased during greedy pattern rewrite"); -#endif // NDEBUG +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope && isOutOfScope(op)) + llvm::report_fatal_error( + "greedy pattern rewrite removed op from region that is out of scope"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (config.listener) config.listener->notifyOperationRemoved(op); @@ -800,16 +845,22 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config, bool *changed) { - // The top-level operation must be known to be isolated from above to - // prevent performing canonicalizations on operations defined at or above - // the region containing 'op'. - assert(region.getParentOp()->hasTrait() && - "patterns can only be applied to operations IsolatedFromAbove"); - // Set scope if not specified. if (!config.scope) config.scope = ®ion; + // Make sure that the specified region on which the greedy rewrite should + // operate is in scope. + assert(config.scope->isAncestor(®ion) && "input region must be in scope"); + + // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that + // are out of scope are never added to the worklist and any out-of-scope IR + // modifications trigger an assertion when expensive expensive checks are + // enabled (as long as the rewriter API is used correctly). + assert( + config.scope->getParentOp()->hasTrait() && + "greedy pattern rewrite scope must be IsolatedFromAbove"); + #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (failed(verify(config.scope->getParentOp()))) llvm::report_fatal_error( @@ -886,7 +937,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef ops, return success(worklist.empty()); } -/// Find the region that is the closest common ancestor of all given ops. +/// Find the IsolateFromAbove region that is the closest common ancestor of all +/// given ops. /// /// Note: This function returns `nullptr` if there is a top-level op among the /// given list of ops. @@ -896,6 +948,7 @@ static Region *findCommonAncestor(ArrayRef ops) { if (ops.size() == 1) return ops.front()->getParentRegion(); + // Find the closest region that contains all ops. Region *region = ops.front()->getParentRegion(); ops = ops.drop_front(); int sz = ops.size(); @@ -912,6 +965,12 @@ static Region *findCommonAncestor(ArrayRef ops) { break; region = region->getParentRegion(); } + + // Find the closest IsolatedFromAbove region. + while (region && + !region->getParentOp()->hasTrait()) + region = region->getParentRegion(); + return region; } @@ -932,8 +991,16 @@ LogicalResult mlir::applyOpPatternsAndFold( // there is a top-level op among `ops`. config.scope = findCommonAncestor(ops); } else { - // If a scope was provided, make sure that all ops are in scope. + // If a scope was provided, make sure that it is IsolatedFromAbove and that + // all ops are in scope. #ifndef NDEBUG + // The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that + // are out of scope are never added to the worklist and any out-of-scope IR + // modifications trigger an assertion when expensive expensive checks are + // enabled (as long as the rewriter API is used correctly). + assert( + config.scope->getParentOp()->hasTrait() && + "greedy pattern rewrite scope must be IsolatedFromAbove"); bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { return static_cast(config.scope->findAncestorOpInRegion(*op)); });