From de4891839acb8b8d1928cd37ba9c6d00db4a2802 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 22 Sep 2021 23:13:27 +0200 Subject: [PATCH 1/4] Add more ChainRules definitions --- src/chainrules.jl | 33 +++++++++++++++++++++++++--- test/chainrules.jl | 54 +++++++++++++++++++++++++++------------------- 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 303ad87c..b63c6eb6 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -16,7 +16,9 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/321 """ ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x)) +ChainRulesCore.@scalar_rule(airyaix(x), airyaiprimex(x) + sqrt(x) * Ω) ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x)) +ChainRulesCore.@scalar_rule(airyaiprimex(x), x * airyaix(x) + sqrt(x) * Ω) ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x)) ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x)) ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x)) @@ -31,12 +33,15 @@ ChainRulesCore.@scalar_rule( ) ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω)) ChainRulesCore.@scalar_rule(digamma(x), trigamma(x)) -ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x)) -ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x)) +ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x^2)) +ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x^2)) +ChainRulesCore.@scalar_rule(logerfc(x), -(2 / sqrt(π)) * exp(-x^2 - Ω)) ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x)) +ChainRulesCore.@scalar_rule(logerfcx(x), 2 * x - (2 / sqrt(π)) * exp(-Ω)) +ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x^2)) ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) + ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x)) ChainRulesCore.@scalar_rule( gamma(a, x), @@ -66,6 +71,7 @@ ChainRulesCore.@scalar_rule( ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x)) # binary +ChainRulesCore.@scalar_rule(erf(x, y), (-(2 / sqrt(π)) * exp(-x^2), (2 / sqrt(π)) * exp(-y^2))) ChainRulesCore.@scalar_rule( besselj(ν, x), ( @@ -94,6 +100,13 @@ ChainRulesCore.@scalar_rule( -(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + besselkx(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + -(besselkx(ν - 1, x) + besselkx(ν + 1, x)) / 2 + Ω, + ), +) ChainRulesCore.@scalar_rule( hankelh1(ν, x), ( @@ -101,6 +114,13 @@ ChainRulesCore.@scalar_rule( (hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + hankelh1x(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + (hankelh1x(ν - 1, x) - hankelh1x(ν + 1, x)) / 2 - im * Ω, + ), +) ChainRulesCore.@scalar_rule( hankelh2(ν, x), ( @@ -108,6 +128,13 @@ ChainRulesCore.@scalar_rule( (hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2, ), ) +ChainRulesCore.@scalar_rule( + hankelh2x(ν, x), + ( + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), + (hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω, + ), +) ChainRulesCore.@scalar_rule( polygamma(m, x), ( diff --git a/test/chainrules.jl b/test/chainrules.jl index 5c164b28..4dfd9b99 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -5,6 +5,7 @@ for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im) test_scalar(erf, x) test_scalar(erfc, x) + test_scalar(erfcx, x) test_scalar(erfi, x) test_scalar(airyai, x) @@ -12,10 +13,12 @@ test_scalar(airybi, x) test_scalar(airybiprime, x) - test_scalar(erfcx, x) test_scalar(dawson, x) if x isa Real + test_scalar(logerfc, x) + test_scalar(logerfcx, x) + test_scalar(invdigamma, x) end @@ -28,6 +31,11 @@ test_scalar(gamma, x) test_scalar(digamma, x) test_scalar(trigamma, x) + + if x isa Real + test_scalar(airyaix, x) + test_scalar(airyaiprimex, x) + end end end end @@ -51,31 +59,38 @@ test_frule(besselk, nu, x) test_rrule(besselk, nu, x) + test_frule(besselkx, nu, x) + test_rrule(besselkx, nu, x) test_frule(bessely, nu, x) test_rrule(bessely, nu, x) - # use complex numbers in `rrule` for FiniteDifferences test_frule(hankelh1, nu, x) - test_rrule(hankelh1, nu, complex(x)) + test_rrule(hankelh1, nu, x) + test_frule(hankelh1x, nu, x) + test_rrule(hankelh1x, nu, x) - # use complex numbers in `rrule` for FiniteDifferences test_frule(hankelh2, nu, x) - test_rrule(hankelh2, nu, complex(x)) + test_rrule(hankelh2, nu, x) + test_frule(hankelh2x, nu, x) + test_rrule(hankelh2x, nu, x) end end end - @testset "beta and logbeta" begin + @testset "erf, beta, and logbeta" begin test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) - for _x in test_points, _y in test_points - # ensure all complex if any complex for FiniteDifferences - x, y = promote(_x, _y) + for x in test_points, y in test_points test_frule(beta, x, y) test_rrule(beta, x, y) test_frule(logbeta, x, y) test_rrule(logbeta, x, y) + + if x isa Real && y isa Real + test_frule(erf, x, y) + test_rrule(erf, x, y) + end end end @@ -91,13 +106,11 @@ isreal(x) && x < 0 && continue test_scalar(loggamma, x) for a in test_points - # ensure all complex if any complex for FiniteDifferences - _a, _x = promote(a, x) - test_frule(gamma, _a, _x; rtol=1e-8) - test_rrule(gamma, _a, _x; rtol=1e-8) + test_frule(gamma, a, x; rtol=1e-8) + test_rrule(gamma, a, x; rtol=1e-8) - test_frule(loggamma, _a, _x) - test_rrule(loggamma, _a, _x) + test_frule(loggamma, a, x) + test_rrule(loggamma, a, x) end isreal(x) || continue @@ -117,14 +130,11 @@ test_scalar(expintx, x) for nu in (-1.5, 2.2, 4.0) - # ensure all complex if any complex for FiniteDifferences - _x, _nu = promote(x, nu) - - test_frule(expint, _nu, _x) - test_rrule(expint, _nu, _x) + test_frule(expint, nu, x) + test_rrule(expint, nu, x) - test_frule(expintx, _nu, _x) - test_rrule(expintx, _nu, _x) + test_frule(expintx, nu, x) + test_rrule(expintx, nu, x) end isreal(x) || continue From 6c6a60ca2f06641d670585189914f23b13c0d362 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 22 Sep 2021 23:13:47 +0200 Subject: [PATCH 2/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dc4d34a0..115f7d6b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SpecialFunctions" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.6.2" +version = "1.7.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 02d5e645ecca1c06ae2a2c0e7cf50929bd0028bb Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 23 Sep 2021 16:17:07 +0200 Subject: [PATCH 3/4] Use irrational constants --- src/chainrules.jl | 25 ++++++++++++++----------- test/chainrules.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index b63c6eb6..a18993de 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -33,14 +33,17 @@ ChainRulesCore.@scalar_rule( ) ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω)) ChainRulesCore.@scalar_rule(digamma(x), trigamma(x)) -ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x^2)) -ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x^2)) -ChainRulesCore.@scalar_rule(logerfc(x), -(2 / sqrt(π)) * exp(-x^2 - Ω)) -ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2)) -ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π))) -ChainRulesCore.@scalar_rule(logerfcx(x), 2 * x - (2 / sqrt(π)) * exp(-Ω)) -ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x^2)) -ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2)) + +# TODO: use `invsqrtπ` if it is added to IrrationalConstants +ChainRulesCore.@scalar_rule(erf(x), (2 * exp(-x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^2)) / sqrtπ)) +ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ) +ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp(Ω^2) / 2))) +ChainRulesCore.@scalar_rule(erfcx(x), 2 * x * Ω - 2 / sqrtπ) +ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ)) +ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ) +ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp(Ω^2) / 2)) ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x)) ChainRulesCore.@scalar_rule( @@ -70,8 +73,7 @@ ChainRulesCore.@scalar_rule( ) ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x)) -# binary -ChainRulesCore.@scalar_rule(erf(x, y), (-(2 / sqrt(π)) * exp(-x^2), (2 / sqrt(π)) * exp(-y^2))) +# Bessel functions ChainRulesCore.@scalar_rule( besselj(ν, x), ( @@ -135,6 +137,7 @@ ChainRulesCore.@scalar_rule( (hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω, ), ) + ChainRulesCore.@scalar_rule( polygamma(m, x), ( @@ -188,5 +191,5 @@ ChainRulesCore.@scalar_rule( ) ) ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) -ChainRulesCore.@scalar_rule(sinint(x), sinc(x / π)) +ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4dfd9b99..5b881766 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -143,4 +143,31 @@ test_scalar(cosint, x) end end + + # https://github.com/JuliaMath/SpecialFunctions.jl/issues/307 + @testset "promotions" begin + # one argument + for f in (erf, erfc, logerfc, erfcinv, logerfcx, erfi, erfinv, sinint) + _, ẏ = frule((NoTangent(), 1f0), f, 1f0) + @test ẏ isa Float32 + _, back = rrule(f, 1f0) + _, x̄ = back(1f0) + @test x̄ isa Float32 + end + + # two arguments + _, ẏ = frule((NoTangent(), 1f0, 1f0), erf, 1f0, 1f0) + @test ẏ isa Float32 + _, back = rrule(erf, 1f0, 1f0) + _, x̄ = back(1f0) + @test x̄ isa Float32 + + # currently broken, can be fixed if `invsqrtπ` is available: + # https://github.com/JuliaMath/IrrationalConstants.jl/pull/8#issuecomment-925828753 + _, ẏ = frule((NoTangent(), 1f0), erfcx, 1f0) + @test_broken ẏ isa Float32 + _, back = rrule(erfcx, 1f0) + _, x̄ = back(1f0) + @test x̄ isa Float32 + end end From 661ee359349fbdfe9aa96128641515512e3583c2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 24 Sep 2021 08:25:58 +0200 Subject: [PATCH 4/4] Convert irrational manually with `oftype` --- src/chainrules.jl | 2 +- test/chainrules.jl | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index a18993de..fa7b5dd3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -40,7 +40,7 @@ ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^ ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ) ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ) ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp(Ω^2) / 2))) -ChainRulesCore.@scalar_rule(erfcx(x), 2 * x * Ω - 2 / sqrtπ) +ChainRulesCore.@scalar_rule(erfcx(x), 2 * (x * Ω - inv(oftype(Ω, sqrtπ)))) ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ)) ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ) ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp(Ω^2) / 2)) diff --git a/test/chainrules.jl b/test/chainrules.jl index 5b881766..d4be8285 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -147,7 +147,7 @@ # https://github.com/JuliaMath/SpecialFunctions.jl/issues/307 @testset "promotions" begin # one argument - for f in (erf, erfc, logerfc, erfcinv, logerfcx, erfi, erfinv, sinint) + for f in (erf, erfc, logerfc, erfcinv, erfcx, logerfcx, erfi, erfinv, sinint) _, ẏ = frule((NoTangent(), 1f0), f, 1f0) @test ẏ isa Float32 _, back = rrule(f, 1f0) @@ -161,13 +161,5 @@ _, back = rrule(erf, 1f0, 1f0) _, x̄ = back(1f0) @test x̄ isa Float32 - - # currently broken, can be fixed if `invsqrtπ` is available: - # https://github.com/JuliaMath/IrrationalConstants.jl/pull/8#issuecomment-925828753 - _, ẏ = frule((NoTangent(), 1f0), erfcx, 1f0) - @test_broken ẏ isa Float32 - _, back = rrule(erfcx, 1f0) - _, x̄ = back(1f0) - @test x̄ isa Float32 end end