Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix evalpoly type instability and 0-length case #56707

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
28 changes: 14 additions & 14 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import .Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin,
using .Base: sign_mask, exponent_mask, exponent_one,
exponent_half, uinttype, significand_mask,
significand_bits, exponent_bits, exponent_bias,
exponent_max, exponent_raw_max, clamp, clamp!
exponent_max, exponent_raw_max, clamp, clamp!, reduce_empty_iter

using Core.Intrinsics: sqrt_llvm

Expand Down Expand Up @@ -94,7 +94,7 @@ julia> evalpoly(2, (1, 2, 3))
function evalpoly(x, p::Tuple)
if @generated
N = length(p.parameters::Core.SimpleVector)
ex = :(p[end])
ex = :(oftype(one(x) * p[end], p[end]))
Copy link
Contributor

@MasonProtter MasonProtter Feb 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry a bit about this causing performance regressions (especially in cases where multiplication is not cheap), but I also get that it's likely important.

Ideally we'd be able to do something in the type domain here rather than an actual multiplication, but oh well i guess. This is probably less brittle, and won't regress in most cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully the compiler can optimize out one(x) * p[end] in most cases, and just compute the type, since the result is not used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if the multiplication is simple enough and free of side effects like memory allocation? If the scalar is eg a BigInt it may not be able to optimize this away?

for i in N-1:-1:1
ex = :(muladd(x, $ex, p[$i]))
end
Expand All @@ -103,20 +103,23 @@ function evalpoly(x, p::Tuple)
_evalpoly(x, p)
end
end
evalpoly(x, ::Tuple{}) = zero(one(x)) # dimensionless zero, i.e. 0 * x^0

evalpoly(x, p::AbstractVector) = _evalpoly(x, p)

function _evalpoly(x, p)
Base.require_one_based_indexing(p)
N = length(p)
ex = p[end]
p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N]
s = oftype(one(x) * p0, p0)
for i in N-1:-1:1
ex = muladd(x, ex, p[i])
s = muladd(x, s, @inbounds p[i])
end
ex
return s
end

function evalpoly(z::Complex, p::Tuple)
# Goertzel-like algorithm from Knuth, TAOCP vol. 2, section 4.6.4:
function evalpoly(z::Complex, p::Tuple{Any, Any, Vararg})
if @generated
N = length(p.parameters)
a = :(p[end])
Expand All @@ -141,17 +144,14 @@ function evalpoly(z::Complex, p::Tuple)
_evalpoly(z, p)
end
end
evalpoly(z::Complex, p::Tuple{<:Any}) = p[1]


evalpoly(z::Complex, p::AbstractVector) = _evalpoly(z, p)

function _evalpoly(z::Complex, p)
Base.require_one_based_indexing(p)
length(p) == 1 && return p[1]
N = length(p)
a = p[end]
b = p[end-1]
p0 = iszero(N) ? reduce_empty_iter(+, p) : @inbounds p[N]
N <= 1 && return oftype(one(z) * p0, p0)
a = p0
@inbounds b = p[N-1]

x = real(z)
y = imag(z)
Expand All @@ -160,7 +160,7 @@ function _evalpoly(z::Complex, p)
for i in N-2:-1:1
ai = a
a = muladd(r, ai, b)
b = muladd(-s, ai, p[i])
b = muladd(-s, ai, @inbounds p[i])
end
ai = a
muladd(ai, z, b)
Expand Down
19 changes: 19 additions & 0 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,10 @@ end
c = 3
@test @evalpoly(c, a0, a1) == 7
@test @evalpoly(1, 2) == 2

isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl")
using .Main.Furlongs
@test @evalpoly(Furlong(2)) === evalpoly(Furlong(2), ()) === evalpoly(Furlong(2), Int[]) === 0
end

@testset "evalpoly real" begin
Expand All @@ -715,6 +719,13 @@ end
@test evalpoly(x, (p1, p2, p3)) == evpm
@test evalpoly(x, [p1, p2, p3]) == evpm
end
@test evalpoly(1.0f0, ()) === 0.0f0 # issue #56699
@test @inferred(evalpoly(1.0f0, Int[])) === 0.0f0 # issue #56699
@test_throws MethodError evalpoly(1.0f0, [])
@test @inferred(evalpoly(1.0f0, [2])) === 2.0f0 # type-stability

# different @generated branches should return same type:
@test evalpoly(3.0, (1,)) === Base.Math._evalpoly(3.0, (1,)) === 1.0
end

@testset "evalpoly complex" begin
Expand All @@ -726,6 +737,14 @@ end
end
@test evalpoly(1+im, (2,)) == 2
@test evalpoly(1+im, [2,]) == 2
@test evalpoly(1.0f0+im, ()) === 0.0f0+0im # issue #56699
@test @inferred(evalpoly(1.0f0+im, Int[])) === 0.0f0+0im # issue #56699
@test_throws MethodError evalpoly(1.0f0, [])
@test @inferred(evalpoly(1.0f0+im, [2])) === 2.0f0+0im # type-stability

# different @generated branches should return same type:
@test evalpoly(3.0+0im, (1,)) === Base.Math._evalpoly(3.0+0im, (1,)) === 1.0+0im
@test evalpoly(3.0+0im, (1,2)) === Base.Math._evalpoly(3.0+0im, (1,2)) === 7.0+0im
end

@testset "cis" begin
Expand Down
Loading