From 5c93eb56dc9bc0c0210483cdd5d31e6b6580454f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 14 Feb 2025 22:38:11 -0600 Subject: [PATCH] [MLIR][Math] Add optional benefit arg to populate math lowering patterns (#127291) Co-authored-by: Ivan R. Ivanov --- .../mlir/Conversion/MathToLLVM/MathToLLVM.h | 4 +- .../mlir/Conversion/MathToLibm/MathToLibm.h | 3 +- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 6 +- mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 95 +++++++++++-------- 4 files changed, 62 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h index 93cd780bba438..0c1203e1e3c0e 100644 --- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h +++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H #define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H +#include "mlir/IR/PatternMatch.h" #include namespace mlir { @@ -23,7 +24,8 @@ class Pass; void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool approximateLog1p = true); + bool approximateLog1p = true, + PatternBenefit benefit = 1); void registerConvertMathToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h index ab9a1cef20cab..8ace53a0fd582 100644 --- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h +++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h @@ -19,7 +19,8 @@ class OperationPass; /// Populate the given list with patterns that convert from Math to Libm calls. /// If log1pBenefit is present, use it instead of benefit for the Log1p op. -void populateMathToLibmConversionPatterns(RewritePatternSet &patterns); +void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Create a pass to convert Math operations to libm calls. std::unique_ptr> createConvertMathToLibmPass(); diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 98680773e00d2..85ec288268aeb 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass void mlir::populateMathToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool approximateLog1p) { + bool approximateLog1p, PatternBenefit benefit) { if (approximateLog1p) - patterns.add(converter); + patterns.add(converter, benefit); // clang-format off patterns.add< AbsFOpLowering, @@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns( FTruncOpLowering, TanOpLowering, TanhOpLowering - >(converter); + >(converter, benefit); // clang-format on } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index a2488dc600f51..12a6d9c3452df 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -50,10 +50,10 @@ template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc) - : OpRewritePattern(context), floatFunc(floatFunc), - doubleFunc(doubleFunc){}; + ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, + StringRef floatFunc, StringRef doubleFunc) + : OpRewritePattern(context, benefit), floatFunc(floatFunc), + doubleFunc(doubleFunc) {}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; @@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern { }; template -void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx, - StringRef floatFunc, StringRef doubleFunc) { - patterns.add, PromoteOpToF32>(ctx); - patterns.add>(ctx, floatFunc, doubleFunc); +void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, + MLIRContext *ctx, StringRef floatFunc, + StringRef doubleFunc) { + patterns.add, PromoteOpToF32>(ctx, benefit); + patterns.add>(ctx, benefit, floatFunc, doubleFunc); } } // namespace @@ -159,42 +160,54 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, return success(); } -void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { +void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { MLIRContext *ctx = patterns.getContext(); - populatePatternsForOp(patterns, ctx, "fabsf", "fabs"); - populatePatternsForOp(patterns, ctx, "acosf", "acos"); - populatePatternsForOp(patterns, ctx, "acoshf", "acosh"); - populatePatternsForOp(patterns, ctx, "asinf", "asin"); - populatePatternsForOp(patterns, ctx, "asinhf", "asinh"); - populatePatternsForOp(patterns, ctx, "atan2f", "atan2"); - populatePatternsForOp(patterns, ctx, "atanf", "atan"); - populatePatternsForOp(patterns, ctx, "atanhf", "atanh"); - populatePatternsForOp(patterns, ctx, "cbrtf", "cbrt"); - populatePatternsForOp(patterns, ctx, "ceilf", "ceil"); - populatePatternsForOp(patterns, ctx, "cosf", "cos"); - populatePatternsForOp(patterns, ctx, "coshf", "cosh"); - populatePatternsForOp(patterns, ctx, "erff", "erf"); - populatePatternsForOp(patterns, ctx, "expf", "exp"); - populatePatternsForOp(patterns, ctx, "exp2f", "exp2"); - populatePatternsForOp(patterns, ctx, "expm1f", "expm1"); - populatePatternsForOp(patterns, ctx, "floorf", "floor"); - populatePatternsForOp(patterns, ctx, "fmaf", "fma"); - populatePatternsForOp(patterns, ctx, "logf", "log"); - populatePatternsForOp(patterns, ctx, "log2f", "log2"); - populatePatternsForOp(patterns, ctx, "log10f", "log10"); - populatePatternsForOp(patterns, ctx, "log1pf", "log1p"); - populatePatternsForOp(patterns, ctx, "powf", "pow"); - populatePatternsForOp(patterns, ctx, "roundevenf", + populatePatternsForOp(patterns, benefit, ctx, "fabsf", "fabs"); + populatePatternsForOp(patterns, benefit, ctx, "acosf", "acos"); + populatePatternsForOp(patterns, benefit, ctx, "acoshf", + "acosh"); + populatePatternsForOp(patterns, benefit, ctx, "asinf", "asin"); + populatePatternsForOp(patterns, benefit, ctx, "asinhf", + "asinh"); + populatePatternsForOp(patterns, benefit, ctx, "atan2f", + "atan2"); + populatePatternsForOp(patterns, benefit, ctx, "atanf", "atan"); + populatePatternsForOp(patterns, benefit, ctx, "atanhf", + "atanh"); + populatePatternsForOp(patterns, benefit, ctx, "cbrtf", "cbrt"); + populatePatternsForOp(patterns, benefit, ctx, "ceilf", "ceil"); + populatePatternsForOp(patterns, benefit, ctx, "cosf", "cos"); + populatePatternsForOp(patterns, benefit, ctx, "coshf", "cosh"); + populatePatternsForOp(patterns, benefit, ctx, "erff", "erf"); + populatePatternsForOp(patterns, benefit, ctx, "expf", "exp"); + populatePatternsForOp(patterns, benefit, ctx, "exp2f", "exp2"); + populatePatternsForOp(patterns, benefit, ctx, "expm1f", + "expm1"); + populatePatternsForOp(patterns, benefit, ctx, "floorf", + "floor"); + populatePatternsForOp(patterns, benefit, ctx, "fmaf", "fma"); + populatePatternsForOp(patterns, benefit, ctx, "logf", "log"); + populatePatternsForOp(patterns, benefit, ctx, "log2f", "log2"); + populatePatternsForOp(patterns, benefit, ctx, "log10f", + "log10"); + populatePatternsForOp(patterns, benefit, ctx, "log1pf", + "log1p"); + populatePatternsForOp(patterns, benefit, ctx, "powf", "pow"); + populatePatternsForOp(patterns, benefit, ctx, "roundevenf", "roundeven"); - populatePatternsForOp(patterns, ctx, "roundf", "round"); - populatePatternsForOp(patterns, ctx, "sinf", "sin"); - populatePatternsForOp(patterns, ctx, "sinhf", "sinh"); - populatePatternsForOp(patterns, ctx, "sqrtf", "sqrt"); - populatePatternsForOp(patterns, ctx, "rsqrtf", "rsqrt"); - populatePatternsForOp(patterns, ctx, "tanf", "tan"); - populatePatternsForOp(patterns, ctx, "tanhf", "tanh"); - populatePatternsForOp(patterns, ctx, "truncf", "trunc"); + populatePatternsForOp(patterns, benefit, ctx, "roundf", + "round"); + populatePatternsForOp(patterns, benefit, ctx, "sinf", "sin"); + populatePatternsForOp(patterns, benefit, ctx, "sinhf", "sinh"); + populatePatternsForOp(patterns, benefit, ctx, "sqrtf", "sqrt"); + populatePatternsForOp(patterns, benefit, ctx, "rsqrtf", + "rsqrt"); + populatePatternsForOp(patterns, benefit, ctx, "tanf", "tan"); + populatePatternsForOp(patterns, benefit, ctx, "tanhf", "tanh"); + populatePatternsForOp(patterns, benefit, ctx, "truncf", + "trunc"); } namespace {