Skip to content

Commit 0032b9e

Browse files
committed
Add sufficient statistics and MLE for Chi and Chisq
1 parent a1b8bb3 commit 0032b9e

File tree

4 files changed

+118
-3
lines changed

4 files changed

+118
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.120"
4+
version = "0.25.121"
55

66
[deps]
77
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"

src/univariate/continuous/chi.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ External links
2323
"""
2424
struct Chi{T<:Real} <: ContinuousUnivariateDistribution
2525
ν::T
26-
Chi{T}::T) where {T} = new{T}(ν)
26+
Chi{T}::Real) where {T<:Real} = new{T}(ν)
2727
end
2828

2929
function Chi::Real; check_args::Bool=true)
@@ -119,3 +119,28 @@ end
119119
rand(rng::AbstractRNG, s::ChiSampler) = sqrt(rand(rng, s.s))
120120

121121
sampler(d::Chi) = ChiSampler(sampler(Chisq(d.ν)))
122+
123+
124+
#### Fitting
125+
126+
struct ChiStats{T<:Real} <: SufficientStats
127+
# (Weighted) sum of log(x)
128+
slogx::T
129+
# Sum of sample weights
130+
sw::T
131+
end
132+
133+
ChiStats(slogx::Real, sw::Real) = ChiStats(promote(slogx, sw)...)
134+
135+
suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}) = ChiStats(mean(log, x), 1)
136+
function suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
137+
if axes(x) != axes(w)
138+
throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal."))
139+
end
140+
ChiStats(
141+
sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))),
142+
sum(w),
143+
)
144+
end
145+
146+
fit_mle(::Type{T}, ss::ChiStats) where {T<:Chi} = T(2 * invdigamma(2 * ss.slogx / ss.sw - logtwo))

src/univariate/continuous/chisq.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ External links
2222
"""
2323
struct Chisq{T<:Real} <: ContinuousUnivariateDistribution
2424
ν::T
25-
Chisq{T}::T) where {T} = new{T}(ν)
25+
Chisq{T}::Real) where {T<:Real} = new{T}(ν)
2626
end
2727

2828
function Chisq::Real; check_args::Bool=true)
@@ -107,3 +107,27 @@ function sampler(d::Chisq)
107107
θ = oftype(α, 2)
108108
return sampler(Gamma{typeof(α)}(α, θ))
109109
end
110+
111+
#### Fitting
112+
113+
struct ChisqStats{T<:Real} <: SufficientStats
114+
# (Weighted) sum of log(x)
115+
slogx::T
116+
# Sum of sample weights
117+
sw::T
118+
end
119+
120+
ChisqStats(slogx::Real, sw::Real) = ChisqStats(promote(slogx, sw)...)
121+
122+
suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}) = ChisqStats(mean(log, x), 1)
123+
function suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
124+
if axes(x) != axes(w)
125+
throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal."))
126+
end
127+
ChisqStats(
128+
sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))),
129+
sum(w),
130+
)
131+
end
132+
133+
fit_mle(::Type{T}, ss::ChisqStats) where {T<:Chisq} = T(2 * invdigamma(ss.slogx / ss.sw - logtwo))

test/fit.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
using Distributions
88
using OffsetArrays
9+
using ForwardDiff
910
using Test, Random, LinearAlgebra
1011

1112

@@ -465,3 +466,68 @@ end
465466

466467
end
467468
end
469+
470+
@testset "Testing fit for Chi" begin
471+
ν = 3.1
472+
for func in funcs, D in (Chi, Chi{Float64}, Chi{Float32})
473+
v = func[1](n0)
474+
z = func[2](D(ν), n0)
475+
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
476+
ss = @inferred suffstats(D, x)
477+
@test ss isa Distributions.ChiStats
478+
@test ss.slogx mean(log.(x))
479+
@test ss.sw == 1.0
480+
481+
d = @inferred fit(D, x)
482+
@test d isa D
483+
@test ForwardDiff.derivative-> sum(logpdf.(Chi(ν), x)), dof(d)) 0 atol = (eps(partype(d)))^(2/3)
484+
485+
if axes(x) == axes(w)
486+
d = @inferred fit(D, x, w)
487+
@test d isa D
488+
@test ForwardDiff.derivative-> dot(logpdf.(Chi(ν), x), w), dof(d)) 0 atol = (eps(partype(d)))^(2/3)
489+
490+
ss = @inferred suffstats(D, x, w)
491+
@test ss isa Distributions.ChiStats
492+
@test ss.slogx dot(w, log.(x))
493+
@test ss.sw == sum(w)
494+
else
495+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w)
496+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w)
497+
end
498+
end
499+
end
500+
end
501+
502+
503+
@testset "Testing fit for Chisq" begin
504+
ν = 4.3
505+
for func in funcs, D in (Chisq, Chisq{Float64}, Chisq{Float32})
506+
v = func[1](n0)
507+
z = func[2](D(ν), n0)
508+
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
509+
ss = @inferred suffstats(D, x)
510+
@test ss isa Distributions.ChisqStats
511+
@test ss.slogx mean(log.(x))
512+
@test ss.sw == 1.0
513+
514+
d = @inferred fit(D, x)
515+
@test d isa D
516+
@test ForwardDiff.derivative-> sum(logpdf.(Chisq(ν), x)), dof(d)) 0 atol = (eps(partype(d)))^(2/3)
517+
518+
if axes(x) == axes(w)
519+
ss = @inferred suffstats(D, x, w)
520+
@test ss isa Distributions.ChisqStats
521+
@test ss.slogx dot(w, log.(x))
522+
@test ss.sw == sum(w)
523+
524+
d = @inferred fit(D, x, w)
525+
@test d isa D
526+
@test ForwardDiff.derivative-> dot(logpdf.(Chisq(ν), x), w), dof(d)) 0 atol = (eps(partype(d)))^(2/3)
527+
else
528+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w)
529+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w)
530+
end
531+
end
532+
end
533+
end

0 commit comments

Comments
 (0)