From bb1ef9db94cc69f79450d7269b9e14cd4a3dcd94 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Thu, 24 Apr 2025 14:47:35 -0700
Subject: [PATCH] [Mosaic] NFC: Refactor Mosaic canonicalization pass to use RewritePattern.

Right now, we manually canonicalize ops that are generated inside another rule (see arith::FPToSIOp and vector::ContractionOp). Also, an op might need both a standalone and an elementwise canonicalization pattern (e.g., arith::Select is suitable for elementwise upcast), and the current approach is bug-prone because each rule invalidates the op. The RewritePattern infrastructure is a good fit here.

One caveat is that we also do (pre-infer) verification in the canonicalization pass, while RewritePattern uses mlir::failure() to decide whether the transformation has converged. I make it tri-state instead:

- If op is rewritten, return `success()`.
- If op is not matched and just passes through, return a silent `failure()` without any diagnostic messages.
- If op is invalid, return `failure()` with a meaningful diagnostic message.

But we should consider separating verification from canonicalization, too.

PiperOrigin-RevId: 751146100
---
 .../tpu/transforms/canonicalize_mosaic.cc | 488 ++++++++++-------- 1 file changed, 277 insertions(+), 211 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 247f47431745..5e809c972553 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -15,7 +15,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -26,7 +25,6 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -40,6 +38,7 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" @@ -49,7 +48,9 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "jaxlib/mosaic/dialect/tpu/vreg_util.h" namespace mlir::tpu { @@ -67,9 +68,109 @@ struct CanonicalizeContext { int hardware_generation; }; -LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, - tpu::MatmulOp op) { - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); +// NOTE: Subclasses must overload the `matchAndRewrite` method and carefully +// return one of the following: +// +// - If op is rewritten, return `success()`. +// - If op is not matched and just passes through, return silent `failure()` +// without any diagnostic messages. +// - If op is invalid, return `failure()` with a meaningful diagnostic message. 
+template +class MosaicOpRewritePattern : public OpRewritePattern { + public: + explicit MosaicOpRewritePattern(MLIRContext *context, + const CanonicalizeContext &ctx) + : OpRewritePattern(context), ctx(ctx) {}; + + protected: + CanonicalizeContext ctx; +}; + +class MosaicRewritePattern : public RewritePattern { + public: + explicit MosaicRewritePattern(MLIRContext *context, + const CanonicalizeContext &ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), ctx(ctx) {} + + protected: + CanonicalizeContext ctx; +}; + +class CanonicalizeTpuMatmulOp : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(tpu::MatmulOp op, + PatternRewriter &rewriter) const override; + + private: + bool match(tpu::MatmulOp op) const; +}; + +class CanonicalizeVectorMultiDimReductionOp + : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern< + vector::MultiDimReductionOp>::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override; +}; + +class CanonicalizeVectorContractionOp + : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; +}; + +class CanonicalizeVectorExtractOp + : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractOp op, + PatternRewriter &rewriter) const override; +}; + +class CanonicalizeArithSelectOp + : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const override; +}; + +class CanonicalizeArithFPToSIOp + : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(arith::FPToSIOp op, + PatternRewriter &rewriter) const override; +}; + +class CanonicalizeTpuRepeatOp : public MosaicOpRewritePattern { + public: + using MosaicOpRewritePattern::MosaicOpRewritePattern; + LogicalResult matchAndRewrite(tpu::RepeatOp op, + PatternRewriter &rewriter) const override; +}; + +class BF16UpcastElementwiseOp : public MosaicRewritePattern { + public: + using MosaicRewritePattern::MosaicRewritePattern; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + + private: + bool match(Operation *op) const; +}; + +LogicalResult CanonicalizeTpuMatmulOp::matchAndRewrite( + tpu::MatmulOp op, PatternRewriter &rewriter) const { + if (!match(op)) { + return failure(); + } + + ImplicitLocOpBuilder builder(op.getLoc(), rewriter); auto transpose_lhs = op.getTransposeLhs(); auto transpose_rhs = op.getTransposeRhs(); @@ -270,52 +371,86 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, // Technically almost identical to the case where batch_size is 1, but // we want to avoid the spurious concat here. 
if (batch_size == 1) { - op.replaceAllUsesWith(outputs[0]); - op.erase(); + rewriter.replaceOp(op, outputs[0]); return success(); } - auto output = builder - .create(op.getLoc(), acc_ty, outputs, - /*dimension=*/0) - .getResult(); - op.replaceAllUsesWith(output); - op.erase(); + rewriter.replaceOpWithNewOp(op, acc_ty, outputs, + /*dimension=*/0); } else { auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult(); - op.replaceAllUsesWith(matmul_res); - op.erase(); + rewriter.replaceOp(op, matmul_res); } return success(); }; -LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, - Operation &op) { - OpBuilder builder(&op); - auto operands = op.getOperands(); - auto res_ty = dyn_cast(op.getResult(0).getType()); - if (op.getNumResults() != 1) { - op.emitOpError("Invariant violated: Unexpected number of results"); +bool CanonicalizeTpuMatmulOp::match(tpu::MatmulOp op) const { + // Downstream Mosaic passes can only handle + // - 2D matmuls + // - All integers or all floats on operands + // - Non-transposed LHS + // Canonicalize it if any of these violates. + auto lhs_ty = op.getLhs().getType(); + auto rhs_ty = op.getRhs().getType(); + auto acc_ty = op.getAcc().getType(); + if (lhs_ty.getShape().size() != 2 || rhs_ty.getShape().size() != 2 || + acc_ty.getShape().size() != 2) { + return true; + } + auto lhs_element_type = lhs_ty.getElementType(); + auto rhs_element_type = rhs_ty.getElementType(); + auto acc_element_type = acc_ty.getElementType(); + if (!(acc_element_type.isInteger() && lhs_element_type.isInteger() && + rhs_element_type.isInteger()) && + !(acc_element_type.isFloat() && lhs_element_type.isFloat() && + rhs_element_type.isFloat())) { + return true; + } + if (op.getTransposeLhs()) { + return true; + } + + ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); + if (!op.getDimensionNumbers().has_value()) { + return true; + } + auto maybe_transposed = isTransposedMatmul(op.getDimensionNumbersAttr()); + if (maybe_transposed.has_value() && maybe_transposed->first) { + return true; + } + return false; +} + +LogicalResult BF16UpcastElementwiseOp::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + if (!match(op)) { + return failure(); + } + + auto operands = op->getOperands(); + auto res_ty = dyn_cast(op->getResult(0).getType()); + if (op->getNumResults() != 1) { + op->emitOpError("Invariant violated: Unexpected number of results"); return failure(); } if (!res_ty) { // scalar // TODO(mvoz): Add canonicalization and invariants for scalar elementwise // ops. - return success(); + return failure(); } auto shape = res_ty.getShape(); std::vector new_operands; new_operands.reserve(operands.size()); bool should_rewrite_op = false; - auto target_f32_ty = VectorType::get(shape, builder.getF32Type()); + auto target_f32_ty = VectorType::get(shape, rewriter.getF32Type()); for (int i = 0; i < operands.size(); ++i) { auto operand = operands[i]; auto ty = dyn_cast(operand.getType()); if (ty) { if (ty.getShape() != shape) { // Should already be checked my MLIR verification, but let's be safe. 
- op.emitOpError("Mismatched shapes in elementwise op."); + op->emitOpError("Mismatched shapes in elementwise op."); return failure(); } auto element_type = ty.getElementType(); @@ -329,12 +464,13 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, if (needs_cast && element_type.isBF16()) { if (ctx.compatibility_mode) { auto target_f32 = - builder.create(op.getLoc(), target_f32_ty, operand) + rewriter + .create(op->getLoc(), target_f32_ty, operand) .getResult(); should_rewrite_op = true; new_operands.push_back(target_f32); } else { - op.emitOpError( + op->emitOpError( "Compatibility mode disabled. Unsupported element type in " "elementwise op on hardware generation: ") << ctx.hardware_generation @@ -346,36 +482,48 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, } } else { // Should already be checked my MLIR verification, but let's be safe. - op.emitOpError("MLIR unsupported - mix scalar and vec elementwise ops"); + op->emitOpError("MLIR unsupported - mix scalar and vec elementwise ops"); return failure(); } } if (should_rewrite_op) { - auto result_ty = dyn_cast(op.getResult(0).getType()); + auto result_ty = dyn_cast(op->getResult(0).getType()); if (!result_ty) { - op.emitOpError("Not implemented: Unexpected result type"); + op->emitOpError("Not implemented: Unexpected result type"); return failure(); } auto result_element_type = result_ty.getElementType(); if (!result_element_type.isF32() && !result_element_type.isBF16()) { - op.emitOpError("Not implemented: Unexpected result element type"); + op->emitOpError("Not implemented: Unexpected result element type"); return failure(); } // Do the new op in f32, then truncate to the original element type. - auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(), - new_operands, target_f32_ty); - new_op = builder.create(op.getLoc(), res_ty, - new_op->getResult(0)); - op.replaceAllUsesWith(new_op); - op.erase(); + auto new_op = rewriter.create(op->getLoc(), op->getName().getIdentifier(), + new_operands, target_f32_ty); + rewriter.replaceOpWithNewOp(op, res_ty, + new_op->getResult(0)); + return success(); } - return success(); + return failure(); } -LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx, - Operation &operation) { - ImplicitLocOpBuilder builder(operation.getLoc(), &operation); - auto op = cast(operation); +bool BF16UpcastElementwiseOp::match(Operation *op) const { + if (isa(op)) { + auto vec_ty = dyn_cast(op->getOperand(0).getType()); + if (vec_ty && vec_ty.getElementType().isBF16() && + ctx.hardware_generation >= 4) { + return false; + } + return true; + } + return isa(op); +} + +LogicalResult CanonicalizeVectorMultiDimReductionOp::matchAndRewrite( + vector::MultiDimReductionOp op, PatternRewriter &rewriter) const { + ImplicitLocOpBuilder builder(op.getLoc(), rewriter); auto source_ty = op.getSourceVectorType(); auto result_ty = dyn_cast(op.getDestType()); if (!result_ty) { @@ -384,7 +532,7 @@ LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx, auto element_type = source_ty.getElementType(); if (element_type.isF32()) { - return success(); + return failure(); } else if (element_type.isBF16()) { bool reduces_sublanes = false; for (int64_t dim : op.getReductionDims()) { @@ -406,7 +554,7 @@ LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx, auto result = acc_ext.fold(arith::ExtFOp::FoldAdaptor(const_acc.getValue())); if (!result.isNull() && result.is()) { - acc_ext->erase(); + 
rewriter.eraseOp(acc_ext); new_acc = builder.create( op.getLoc(), result_ty_f32, cast(result.get())); @@ -415,59 +563,42 @@ LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx, auto new_op = builder.create( op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc, DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims())); - auto new_result = builder.create(op.getLoc(), result_ty, - new_op.getResult()); - op.replaceAllUsesWith(new_result.getResult()); - op.erase(); + rewriter.replaceOpWithNewOp(op, result_ty, + new_op.getResult()); + return success(); } - return success(); + return failure(); } else if (element_type.isSignlessInteger(32) && // TODO(b/384774084): Add support for u32 reductions. (op.getKind() == vector::CombiningKind::ADD || op.getKind() == vector::CombiningKind::MAXSI || op.getKind() == vector::CombiningKind::MINSI)) { - return success(); + return failure(); } op.emitOpError("Unsupported element type for the selected reduction"); return failure(); } -LogicalResult canonicalize_matmul(const CanonicalizeContext &ctx, - Operation &op) { - auto matmul_op = dyn_cast(op); - if (!matmul_op) { - op.emitOpError("Invariant violated: Not a matmul"); - return failure(); - } - return tpu_matmul_rule(ctx, matmul_op); -}; - -LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx, - Operation &op) { - auto contraction_op = dyn_cast(op); - if (!contraction_op) { - op.emitOpError("Invariant violated: Not a contraction"); - return failure(); - } +LogicalResult CanonicalizeVectorContractionOp::matchAndRewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { // Rewrite the contraction as a matmul - auto lhs = contraction_op.getLhs(); - auto rhs = contraction_op.getRhs(); - auto acc = contraction_op.getAcc(); + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto acc = op.getAcc(); VectorType acc_ty; if (!(acc_ty = dyn_cast(acc.getType()))) { - contraction_op->emitOpError("Not implemented: acc must be a vector"); + op->emitOpError("Not implemented: acc must be a vector"); return failure(); } - if (contraction_op.getKind() != vector::CombiningKind::ADD) { - contraction_op->emitOpError("Only ADD supported"); + if (op.getKind() != vector::CombiningKind::ADD) { + op->emitOpError("Only ADD supported"); return failure(); } - ImplicitLocOpBuilder builder(contraction_op->getLoc(), - contraction_op.getOperation()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - MLIRContext *const mlir_ctx = contraction_op->getContext(); + MLIRContext *const mlir_ctx = op->getContext(); auto getMapAttr = [&](const unsigned first, const unsigned second) { return AffineMapAttr::get(AffineMap::get( @@ -480,10 +611,10 @@ LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx, {getMapAttr(0, 2), getMapAttr(2, 1), getMapAttr(0, 1)}); const ArrayAttr matmul_indexing_maps_transposed = builder.getArrayAttr( {getMapAttr(0, 2), getMapAttr(1, 2), getMapAttr(0, 1)}); - const auto indexing_maps = contraction_op.getIndexingMaps(); + const auto indexing_maps = op.getIndexingMaps(); if (indexing_maps != matmul_indexing_maps && indexing_maps != matmul_indexing_maps_transposed) { - return contraction_op->emitOpError( + return op->emitOpError( "Not implemented: Non-matmul or unsupported indexing_maps"); } const bool transpose_rhs = indexing_maps == matmul_indexing_maps_transposed; @@ -495,29 +626,24 @@ LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx, vector::IteratorType::parallel), builder.getAttr( 
vector::IteratorType::reduction)}); - if (contraction_op->getAttr("iterator_types") != matmul_iterator_types) { - return contraction_op->emitOpError( - "Not implemented: Non-matmul iterator_types"); + if (op->getAttr("iterator_types") != matmul_iterator_types) { + return op->emitOpError("Not implemented: Non-matmul iterator_types"); } const tpu::ContractPrecisionAttr precision_attr = // May be null - contraction_op->getAttrOfType("precision"); + op->getAttrOfType("precision"); const auto dot_dimension_numbers_attr = defaultDimensionNumbers(builder, false, transpose_rhs); - auto matmul_op = builder.create( - contraction_op->getLoc(), acc_ty, lhs, rhs, acc, + rewriter.replaceOpWithNewOp( + op, acc_ty, lhs, rhs, acc, /*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr); - contraction_op.replaceAllUsesWith(matmul_op.getResult()); - contraction_op.erase(); - auto result = tpu_matmul_rule(ctx, matmul_op); - return result; + return success(); } -LogicalResult canonicalize_extract(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); +LogicalResult CanonicalizeVectorExtractOp::matchAndRewrite( + vector::ExtractOp op, PatternRewriter &rewriter) const { Type result_ty = op.getResult().getType(); if (!isa(result_ty)) { bool is_supported = result_ty.isSignlessIntOrFloat() && @@ -528,33 +654,29 @@ LogicalResult canonicalize_extract(const CanonicalizeContext &ctx, "32-bit type first."); } } - return success(); + return failure(); } -LogicalResult canonicalize_select(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); +LogicalResult CanonicalizeArithSelectOp::matchAndRewrite( + arith::SelectOp op, PatternRewriter &rewriter) const { if (!isa(op.getType()) || isa(op.getCondition().getType())) { - return success(); + return failure(); } // Canonicalize `i1 ? v1 : v2` -> `broadcast(i1) ? v1 : v2`. - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto cond_ty = VectorType::get(cast(op.getType()).getShape(), op.getCondition().getType()); auto cond = builder.create(cond_ty, op.getCondition()); - auto new_op = builder.create( - op.getLoc(), cond, op.getTrueValue(), op.getFalseValue()); - op.replaceAllUsesWith(new_op.getResult()); - op.erase(); + rewriter.replaceOpWithNewOp(op, cond, op.getTrueValue(), + op.getFalseValue()); return success(); } // All conversions that change bitwidth must be canonicalized to tpu.fptosi. 
-LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = cast(raw_op); - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); +LogicalResult CanonicalizeArithFPToSIOp::matchAndRewrite( + arith::FPToSIOp op, PatternRewriter &rewriter) const { + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto src_vty = dyn_cast(op.getIn().getType()); auto dst_vty = dyn_cast(op.getType()); if (static_cast(src_vty) != static_cast(dst_vty)) { @@ -569,6 +691,9 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth(); dst_bitwidth = op.getType().getIntOrFloatBitWidth(); } + if (src_bitwidth == 32 && dst_bitwidth == 32) { + return failure(); + } if (dst_bitwidth > 32) { return op.emitOpError("Target bitwidth too large"); } @@ -577,19 +702,8 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, src_vty.getElementType().isBF16() && (dst_vty.getElementType().isSignlessInteger(8) || dst_vty.getElementType().isSignlessInteger(4))) { - auto new_op = builder.create( - op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero); - op.replaceAllUsesWith(new_op.getResult()); - op.erase(); - // We briefly trigger canonicalization here to potentially fuse the rounding - // ops into the newly created tpu.fptosi. - { - PatternRewriter rewriter(new_op.getContext()); - rewriter.setInsertionPoint(new_op); - // We don't care if the canonicalization pattern matched or not. - (void)tpu::FPToSIOp::canonicalize(new_op, rewriter); - new_op = nullptr; // Canonicalization may have erased the op! - } + rewriter.replaceOpWithNewOp(op, op.getType(), op.getIn(), + tpu::RoundingMode::kTowardsZero); return success(); } Value x = op.getIn(); @@ -643,14 +757,12 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, } x = builder.create(op.getType(), x); } - op.replaceAllUsesWith(x); - op.erase(); + rewriter.replaceOp(op, x); return success(); } -LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = dyn_cast(raw_op); +LogicalResult CanonicalizeTpuRepeatOp::matchAndRewrite( + tpu::RepeatOp op, PatternRewriter &rewriter) const { if (!isa(op.getType())) { return op.emitOpError("Only vector types supported"); } @@ -659,100 +771,17 @@ LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, if (times == 1) { // A true no op - kind of an odd edge case, but this does come up in // flash_attention_backward tests. 
- op.replaceAllUsesWith(operand); - op.erase(); + rewriter.replaceOp(op, operand); return success(); } auto operands = std::vector(times, operand); - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto concat = builder.create(op.getLoc(), op.getType(), operands, op.getDimension()); - op.replaceAllUsesWith(concat.getResult()); - op.erase(); + rewriter.replaceOp(op, concat.getResult()); return success(); } -using canonicalize_rule_type = - std::function; - -const llvm::StringMap &rules() { - static auto rules = new llvm::StringMap{ - {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, - {vector::ContractionOp::getOperationName(), canonicalize_contraction}, - {vector::ExtractOp::getOperationName(), canonicalize_extract}, - {vector::MultiDimReductionOp::getOperationName(), - canonicalize_multi_dim_reduction}, - {arith::SelectOp::getOperationName(), canonicalize_select}, - {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, - {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; - return *rules; -} - -bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) { - if (isa(op)) { - auto vec_ty = dyn_cast(op.getOperand(0).getType()); - if (vec_ty && vec_ty.getElementType().isBF16() && - ctx.hardware_generation >= 4) { - return false; - } - return true; - } - return isa(op); -} - -class MosaicCanonicalizer { - public: - MosaicCanonicalizer(int hardware_generation, bool compatibility_mode) - : hardware_generation_(hardware_generation), - compatibility_mode_(compatibility_mode) {} - - int hardware_generation_; - bool compatibility_mode_; - - LogicalResult canonicalize(func::FuncOp op) { - if (!op.getBody().hasOneBlock()) { - op.emitOpError("Only one block functions supported"); - return failure(); - } - return canonicalizeBlock(op.getBody().front()); - } - - LogicalResult canonicalizeBlock(Block &block) { - // make_early_inc_range is utilized due to op mutation. - for (Operation &any_op : make_early_inc_range(block)) { - if (canonicalizeOp(any_op).failed()) { - return failure(); - } - } - return success(); - } - - LogicalResult canonicalizeOp(Operation &any_op) { - CanonicalizeContext ctx({compatibility_mode_, hardware_generation_}); - // We must iterate over the op first, because canonicalization can cause - // us to .erase() an op, and accessing getRegions on it after is not sound. - // Invariant - top level ops with regions may never be invalidated. 
- for (Region ®ion : any_op.getRegions()) { - for (Block &block : region) { - if (canonicalizeBlock(block).failed()) { - return failure(); - } - } - } - if (need_elementwise_canonicalization(ctx, any_op)) { - return canonicalize_elementwise(ctx, any_op); - } - if (auto rule_it = rules().find(any_op.getName().getStringRef()); - rule_it != rules().end()) { - const canonicalize_rule_type &rule = rule_it->getValue(); - return rule(ctx, any_op); - } - return success(); - } -}; - struct CanonicalizeMosaicPass : public impl::CanonicalizeMosaicPassBase { CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p) @@ -760,10 +789,47 @@ struct CanonicalizeMosaicPass this->hardware_generation = hardware_generation_p; } + RewritePatternSet getCanonicalizePatterns() { + RewritePatternSet patterns(&getContext()); + CanonicalizeContext ctx({compatibility_mode_, hardware_generation}); + patterns.add(&getContext(), ctx); + patterns.add(&getContext(), ctx); + tpu::FPToSIOp::getCanonicalizationPatterns(patterns, &getContext()); + return patterns; + } + void runOnOperation() override { + RewritePatternSet patterns = getCanonicalizePatterns(); + GreedyRewriteConfig config = { + .useTopDownTraversal = true, + .enableRegionSimplification = GreedySimplifyRegionLevel::Disabled, + .maxIterations = 3, + .strictMode = GreedyRewriteStrictness::ExistingAndNewOps, + .fold = false, + .cseConstants = false, + }; + func::FuncOp func = getOperation(); - MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_); - if (vlc.canonicalize(func).failed()) { + std::optional diag; + auto handler = std::make_optional( + func.getContext(), [&](Diagnostic &d) { + if (!d.str().empty()) { + diag = std::move(d); + } + }); + // If canonicalization does not converge OR there is a unrecoverable + // failure, signal a pass failure. + if (failed(applyPatternsGreedily(func.getBody(), std::move(patterns), + config)) || + diag.has_value()) { + if (diag.has_value()) { + // Clean up RAII, otherwise it will eat the last diagnostic. + handler.reset(); + emitError(diag->getLocation()) << diag->str(); + } signalPassFailure(); } };
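(Not part of the patch; an illustrative sketch only.) The tri-state convention described in the commit message boils down to the pattern skeleton below, built on the `MosaicOpRewritePattern` base and `CanonicalizeContext` introduced in this change. `FooOp`, `BarOp`, `isAlreadyCanonical`, and `isSupported` are hypothetical placeholders, not real Mosaic ops or helpers:

```cpp
// Sketch only: a hypothetical pattern that follows the tri-state return
// convention. FooOp/BarOp and the two predicates are placeholders; the base
// class and `ctx` member are the ones added by this patch.
class CanonicalizeFooOp : public MosaicOpRewritePattern<FooOp> {
 public:
  using MosaicOpRewritePattern::MosaicOpRewritePattern;

  LogicalResult matchAndRewrite(FooOp op,
                                PatternRewriter &rewriter) const override {
    // Pass-through: the op is already canonical. Return failure() silently so
    // the greedy driver simply moves on; no diagnostic means no pass failure.
    if (isAlreadyCanonical(op)) {
      return failure();
    }
    // Invalid op: emit a diagnostic and return failure(). The pass captures
    // the diagnostic with the handler it installs and turns it into
    // signalPassFailure(), even if the rewrite driver itself converges.
    if (!isSupported(op, ctx)) {
      return op.emitOpError("unsupported on hardware generation ")
             << ctx.hardware_generation;
    }
    // Rewrite: replace the op and return success() so the driver keeps
    // iterating (bounded by maxIterations in the pass config).
    rewriter.replaceOpWithNewOp<BarOp>(op, op->getResult(0).getType(),
                                       op->getOperand(0));
    return success();
  }
};
```

A pattern like this would be registered in getCanonicalizePatterns() via `patterns.add<CanonicalizeFooOp>(&getContext(), ctx);`, alongside the patterns the pass already adds.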