Skip to content

Commit a9b1790

Browse files
committed
Add sufficient statistics and MLE for Chi and Chisq
1 parent a1b8bb3 commit a9b1790

File tree

5 files changed

+104
-3
lines changed

5 files changed

+104
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.120"
4+
version = "0.25.121"
55

66
[deps]
77
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"

docs/src/fit.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ The `fit_mle` method has been implemented for the following distributions:
4343
- [`Beta`](@ref)
4444
- [`Binomial`](@ref)
4545
- [`Categorical`](@ref)
46+
- [`Chi`](@ref)
47+
- [`Chisq`](@ref)
4648
- [`DiscreteUniform`](@ref)
4749
- [`Exponential`](@ref)
4850
- [`LogNormal`](@ref)

src/univariate/continuous/chi.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ External links
2323
"""
2424
struct Chi{T<:Real} <: ContinuousUnivariateDistribution
2525
ν::T
26-
Chi{T}(ν::T) where {T} = new{T}(ν)
26+
Chi{T}(ν::Real) where {T<:Real} = new{T}(ν)
2727
end
2828

2929
function Chi(ν::Real; check_args::Bool=true)
@@ -119,3 +119,22 @@ end
119119
rand(rng::AbstractRNG, s::ChiSampler) = sqrt(rand(rng, s.s))
120120

121121
sampler(d::Chi) = ChiSampler(sampler(Chisq(d.ν)))
122+
123+
124+
#### Fitting
125+
126+
struct ChiStats{T<:Real} <: SufficientStats
127+
# (Weighted) mean of log(x)
128+
mlogx::T
129+
end
130+
131+
suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}) = ChiStats(mean(log, x))
132+
function suffstats(::Type{<:Chi}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
133+
if axes(x) != axes(w)
134+
throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal."))
135+
end
136+
mlogx = sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))) / sum(w)
137+
return ChiStats(mlogx)
138+
end
139+
140+
fit_mle(::Type{T}, ss::ChiStats) where {T<:Chi} = T(2 * invdigamma(2 * ss.mlogx - logtwo))

src/univariate/continuous/chisq.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ External links
2222
"""
2323
struct Chisq{T<:Real} <: ContinuousUnivariateDistribution
2424
ν::T
25-
Chisq{T}(ν::T) where {T} = new{T}(ν)
25+
Chisq{T}(ν::Real) where {T<:Real} = new{T}(ν)
2626
end
2727

2828
function Chisq(ν::Real; check_args::Bool=true)
@@ -107,3 +107,21 @@ function sampler(d::Chisq)
107107
θ = oftype(α, 2)
108108
return sampler(Gamma{typeof(α)}(α, θ))
109109
end
110+
111+
#### Fitting
112+
113+
struct ChisqStats{T<:Real} <: SufficientStats
114+
# (Weighted) mean of log(x)
115+
mlogx::T
116+
end
117+
118+
suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}) = ChisqStats(mean(log, x))
119+
function suffstats(::Type{<:Chisq}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
120+
if axes(x) != axes(w)
121+
throw(DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal."))
122+
end
123+
mlogx = sum(Broadcast.instantiate(Broadcast.broadcasted(xlogy, w, x))) / sum(w)
124+
return ChisqStats(mlogx)
125+
end
126+
127+
fit_mle(::Type{T}, ss::ChisqStats) where {T<:Chisq} = T(2 * invdigamma(ss.mlogx - logtwo))

test/fit.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
using Distributions
88
using OffsetArrays
9+
using ForwardDiff
910
using Test, Random, LinearAlgebra
1011

1112

@@ -465,3 +466,64 @@ end
465466

466467
end
467468
end
469+
470+
@testset "Testing fit for Chi" begin
471+
ν = 3.1
472+
for func in funcs, D in (Chi, Chi{Float64}, Chi{Float32})
473+
v = func[1](n0)
474+
z = func[2](D(ν), n0)
475+
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
476+
ss = @inferred suffstats(D, x)
477+
@test ss isa Distributions.ChiStats
478+
@test ss.mlogx ≈ mean(log.(x))
479+
480+
d = @inferred fit(D, x)
481+
@test d isa D
482+
@test ForwardDiff.derivative(ν -> sum(logpdf.(Chi(ν), x)), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3)
483+
484+
if axes(x) == axes(w)
485+
d = @inferred fit(D, x, w)
486+
@test d isa D
487+
@test ForwardDiff.derivative(ν -> dot(logpdf.(Chi(ν), x), w), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3)
488+
489+
ss = @inferred suffstats(D, x, w)
490+
@test ss isa Distributions.ChiStats
491+
@test ss.mlogx ≈ dot(w ./ sum(w), log.(x))
492+
else
493+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w)
494+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w)
495+
end
496+
end
497+
end
498+
end
499+
500+
501+
@testset "Testing fit for Chisq" begin
502+
ν = 4.3
503+
for func in funcs, D in (Chisq, Chisq{Float64}, Chisq{Float32})
504+
v = func[1](n0)
505+
z = func[2](D(ν), n0)
506+
for x in (z, OffsetArray(z, -n0 ÷ 2)), w in (v, OffsetArray(v, -n0 ÷ 2))
507+
ss = @inferred suffstats(D, x)
508+
@test ss isa Distributions.ChisqStats
509+
@test ss.mlogx ≈ mean(log.(x))
510+
511+
d = @inferred fit(D, x)
512+
@test d isa D
513+
@test ForwardDiff.derivative(ν -> sum(logpdf.(Chisq(ν), x)), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3)
514+
515+
if axes(x) == axes(w)
516+
ss = @inferred suffstats(D, x, w)
517+
@test ss isa Distributions.ChisqStats
518+
@test ss.mlogx ≈ dot(w ./ sum(w), log.(x))
519+
520+
d = @inferred fit(D, x, w)
521+
@test d isa D
522+
@test ForwardDiff.derivative(ν -> dot(logpdf.(Chisq(ν), x), w), dof(d)) ≈ 0 atol = (eps(partype(d)))^(2/3)
523+
else
524+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") suffstats(D, x, w)
525+
@test_throws DimensionMismatch("Inconsistent array dimensions: Axes of samples and sample weights must be equal.") fit(D, x, w)
526+
end
527+
end
528+
end
529+
end

0 commit comments

Comments
 (0)