diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h index 93cd780bba438..0c1203e1e3c0e 100644 --- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h +++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H #define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H +#include "mlir/IR/PatternMatch.h" #include namespace mlir { @@ -23,7 +24,8 @@ class Pass; void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool approximateLog1p = true); + bool approximateLog1p = true, + PatternBenefit benefit = 1); void registerConvertMathToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h index 1fac96abef621..7120d129952d8 100644 --- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h +++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h @@ -20,7 +20,8 @@ class OperationPass; /// Populate the given list with patterns that convert from Math to Libm calls. /// If log1pBenefit is present, use it instead of benefit for the Log1p op. void populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, const ConvertMathToLibmOptions &options); + RewritePatternSet &patterns, const ConvertMathToLibmOptions &options, + PatternBenefit benefit = 1); /// Create a pass to convert Math operations to libm calls. std::unique_ptr> createConvertMathToLibmPass(); diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 98680773e00d2..85ec288268aeb 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass void mlir::populateMathToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool approximateLog1p) { + bool approximateLog1p, PatternBenefit benefit) { if (approximateLog1p) - patterns.add(converter); + patterns.add(converter, benefit); // clang-format off patterns.add< AbsFOpLowering, @@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns( FTruncOpLowering, TanOpLowering, TanhOpLowering - >(converter); + >(converter, benefit); // clang-format on } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index e2d929e9fa0e9..ec4b1b9b0162e 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -50,10 +50,10 @@ template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc) - : OpRewritePattern(context), floatFunc(floatFunc), - doubleFunc(doubleFunc){}; + ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, + StringRef floatFunc, StringRef doubleFunc) + : OpRewritePattern(context, benefit), floatFunc(floatFunc), + doubleFunc(doubleFunc) {}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; @@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern { }; template -void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx, - StringRef floatFunc, StringRef doubleFunc) { - patterns.add, PromoteOpToF32>(ctx); - patterns.add>(ctx, floatFunc, doubleFunc); +void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, + MLIRContext *ctx, StringRef floatFunc, + StringRef doubleFunc) { + patterns.add, PromoteOpToF32>(ctx, benefit); + patterns.add>(ctx, benefit, floatFunc, doubleFunc); } } // namespace @@ -162,49 +163,49 @@ ScalarOpToLibmCall::matchAndRewrite(Op op, } void mlir::populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, const ConvertMathToLibmOptions &options) { + RewritePatternSet &patterns, const ConvertMathToLibmOptions &options, PatternBenefit benefit) { MLIRContext *ctx = patterns.getContext(); - populatePatternsForOp(patterns, ctx, "fabsf", "fabs"); - populatePatternsForOp(patterns, ctx, "acosf", "acos"); - populatePatternsForOp(patterns, ctx, "acoshf", "acosh"); - populatePatternsForOp(patterns, ctx, "asinf", "asin"); - populatePatternsForOp(patterns, ctx, "asinhf", "asinh"); - populatePatternsForOp(patterns, ctx, "atan2f", "atan2"); - populatePatternsForOp(patterns, ctx, "atanf", "atan"); - populatePatternsForOp(patterns, ctx, "atanhf", "atanh"); - populatePatternsForOp(patterns, ctx, "cbrtf", "cbrt"); - populatePatternsForOp(patterns, ctx, "ceilf", "ceil"); - populatePatternsForOp(patterns, ctx, "cosf", "cos"); - populatePatternsForOp(patterns, ctx, "coshf", "cosh"); - populatePatternsForOp(patterns, ctx, "erff", "erf"); - populatePatternsForOp(patterns, ctx, "expf", "exp"); - populatePatternsForOp(patterns, ctx, "exp2f", "exp2"); - populatePatternsForOp(patterns, ctx, "expm1f", "expm1"); - populatePatternsForOp(patterns, ctx, "floorf", "floor"); - populatePatternsForOp(patterns, ctx, "fmaf", "fma"); - populatePatternsForOp(patterns, ctx, "logf", "log"); - populatePatternsForOp(patterns, ctx, "log2f", "log2"); - populatePatternsForOp(patterns, ctx, "log10f", "log10"); - populatePatternsForOp(patterns, ctx, "log1pf", "log1p"); - populatePatternsForOp(patterns, ctx, "powf", "pow"); + populatePatternsForOp(patterns, benefit, ctx, "fabsf", "fabs"); + populatePatternsForOp(patterns, benefit, ctx, "acosf", "acos"); + populatePatternsForOp(patterns, benefit, ctx, "acoshf", "acosh"); + populatePatternsForOp(patterns, benefit, ctx, "asinf", "asin"); + populatePatternsForOp(patterns, benefit, ctx, "asinhf", "asinh"); + populatePatternsForOp(patterns, benefit, ctx, "atan2f", "atan2"); + populatePatternsForOp(patterns, benefit, ctx, "atanf", "atan"); + populatePatternsForOp(patterns, benefit, ctx, "atanhf", "atanh"); + populatePatternsForOp(patterns, benefit, ctx, "cbrtf", "cbrt"); + populatePatternsForOp(patterns, benefit, ctx, "ceilf", "ceil"); + populatePatternsForOp(patterns, benefit, ctx, "cosf", "cos"); + populatePatternsForOp(patterns, benefit, ctx, "coshf", "cosh"); + populatePatternsForOp(patterns, benefit, ctx, "erff", "erf"); + populatePatternsForOp(patterns, benefit, ctx, "expf", "exp"); + populatePatternsForOp(patterns, benefit, ctx, "exp2f", "exp2"); + populatePatternsForOp(patterns, benefit, ctx, "expm1f", "expm1"); + populatePatternsForOp(patterns, benefit, ctx, "floorf", "floor"); + populatePatternsForOp(patterns, benefit, ctx, "fmaf", "fma"); + populatePatternsForOp(patterns, benefit, ctx, "logf", "log"); + populatePatternsForOp(patterns, benefit, ctx, "log2f", "log2"); + populatePatternsForOp(patterns, benefit, ctx, "log10f", "log10"); + populatePatternsForOp(patterns, benefit, ctx, "log1pf", "log1p"); + populatePatternsForOp(patterns, benefit, ctx, "powf", "pow"); if (options.allowC23Features) - populatePatternsForOp(patterns, ctx, "roundevenf", + populatePatternsForOp(patterns, benefit, ctx, "roundevenf", "roundeven"); else if (options.roundingModeIsDefault) - populatePatternsForOp(patterns, ctx, "nearbyintf", + populatePatternsForOp(patterns, benefit, ctx, "nearbyintf", "nearbyint"); // Roundeven: using nearbyint (pre-C23) for roundeven requires the // rounding mode to be FE_TONEAREST (the default). Otherwise we need to // issue a call to set the rounding mode (which this pass currently can't do). - populatePatternsForOp(patterns, ctx, "roundf", "round"); - populatePatternsForOp(patterns, ctx, "sinf", "sin"); - populatePatternsForOp(patterns, ctx, "sinhf", "sinh"); - populatePatternsForOp(patterns, ctx, "sqrtf", "sqrt"); - populatePatternsForOp(patterns, ctx, "rsqrtf", "rsqrt"); - populatePatternsForOp(patterns, ctx, "tanf", "tan"); - populatePatternsForOp(patterns, ctx, "tanhf", "tanh"); - populatePatternsForOp(patterns, ctx, "truncf", "trunc"); + populatePatternsForOp(patterns, benefit, ctx, "roundf", "round"); + populatePatternsForOp(patterns, benefit, ctx, "sinf", "sin"); + populatePatternsForOp(patterns, benefit, ctx, "sinhf", "sinh"); + populatePatternsForOp(patterns, benefit, ctx, "sqrtf", "sqrt"); + populatePatternsForOp(patterns, benefit, ctx, "rsqrtf", "rsqrt"); + populatePatternsForOp(patterns, benefit, ctx, "tanf", "tan"); + populatePatternsForOp(patterns, benefit, ctx, "tanhf", "tanh"); + populatePatternsForOp(patterns, benefit, ctx, "truncf", "trunc"); } namespace {