Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
dd6376e
create mvhypergeomtric.jl
Michael-Howes Apr 8, 2025
8bcfaba
add MvHypergeometric to module
Michael-Howes Apr 8, 2025
7f494af
add multivariate hypergeometric sampler
Michael-Howes Apr 9, 2025
47e416d
add tests
Michael-Howes Apr 9, 2025
585caff
adding tests
Michael-Howes Apr 9, 2025
43e824a
add tests comparing multivariate hypergeometric to univariate hyperge…
Michael-Howes Apr 9, 2025
53b005f
add comments to mvhypergeometric sampler
Michael-Howes Apr 9, 2025
f2d6b4f
rename test file for multivariate hypergeometric
Michael-Howes Apr 9, 2025
4bdaca4
append mvhypergeometric tests to runtest.jl
Michael-Howes Apr 9, 2025
d17c249
update tests
Michael-Howes Apr 9, 2025
9e37f95
update project.toml
Michael-Howes Apr 9, 2025
91d9da5
update tests to improve coverage
Michael-Howes Apr 9, 2025
74240a6
Merge branch 'JuliaStats:master' into master
Michael-Howes Oct 7, 2025
88878b4
rename tests
Michael-Howes Oct 7, 2025
e6b67d7
Update file names
Michael-Howes Oct 8, 2025
f7b61ff
Update documentation
Michael-Howes Oct 8, 2025
58f992d
Improve test coverage
Michael-Howes Oct 8, 2025
65c83f7
Update sampling function to remove redundant logic.
Michael-Howes Oct 9, 2025
dbeeb04
formate with julia vscode extension
Michael-Howes Oct 11, 2025
30c7882
reformat
Michael-Howes Oct 11, 2025
fc068db
Merge branch 'master' into master
Michael-Howes Nov 5, 2025
3f5985c
Remove nelements functions
Michael-Howes Nov 10, 2025
a5f8cc0
Remove nelements from test
Michael-Howes Nov 10, 2025
17ee592
remove use of nelements
Michael-Howes Nov 10, 2025
a01a3e1
Add explicitly defined partype
Michael-Howes Nov 10, 2025
4d2ed35
Remove abstractly typed fields
Michael-Howes Nov 10, 2025
c0a6cc0
Remove AbstractVector{Int{
Michael-Howes Nov 10, 2025
5fb20ff
Remove where clause in type signature
Michael-Howes Nov 10, 2025
4d58d63
Remove where clause in type signature
Michael-Howes Nov 10, 2025
b0c2386
Optimize @check_args
Michael-Howes Nov 10, 2025
53f875a
Improve var computation
Michael-Howes Nov 10, 2025
7042107
Change cov computation
Michael-Howes Nov 10, 2025
947fed3
Optimize insupport
Michael-Howes Nov 10, 2025
d096ba5
Remove inbounds from _logpdf
Michael-Howes Nov 10, 2025
20d68b8
Remove inbounds from sampling
Michael-Howes Nov 10, 2025
4853cbc
Remove inbounds from sampling
Michael-Howes Nov 10, 2025
84dfbb6
Move sampling method into main file
Michael-Howes Nov 10, 2025
cccd79f
Update formatting
Michael-Howes Nov 10, 2025
6c20f7e
Add partype test
Michael-Howes Nov 10, 2025
b3496ab
Check support during compution of _logpdf
Michael-Howes Nov 10, 2025
26e33f7
Update docstring
Michael-Howes Nov 10, 2025
afe12d1
Update support check in _logpdf
Michael-Howes Nov 11, 2025
6b0ad57
Merge branch 'JuliaStats:master' into master
Michael-Howes Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ MvNormalCanon
MvLogitNormal
MvLogNormal
Dirichlet
MvHypergeometric
```

## Addition Methods
Expand Down
1 change: 1 addition & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export
MatrixTDist,
MixtureModel,
Multinomial,
MvHypergeometric,
MultivariateNormal,
MvLogNormal,
MvNormal,
Expand Down
131 changes: 131 additions & 0 deletions src/multivariate/mvhypergeometric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
MvHypergeometric(m, n)

The [Multivariate hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution#Multivariate_hypergeometric_distribution)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Julia docstrings should contain the type or function signature in the first line, indented by four spaces.

Generally, could you make this docstring consistent with docstrings of other existing distribution types?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the function signature. The doc string is based on the multinomial distribution docstring. Is there a different distribution I should be following?

https://github.com/JuliaStats/Distributions.jl/blob/0421b1890c49417de993b255f6b177a1cb93fca3/src/multivariate/multinomial.jl

generalizes the *hypergeometric distribution*. Consider ``n`` draws without replacement from a finite population containing ``k`` types of elements. Suppose that the population has size ``M`` and there are ``m_i`` elements of type ``i`` for ``i = 1, \\ldots, k`` with ``m_1 + \\cdots + m_k = M``. Let ``X = (X_1, \\ldots, X_k)`` where ``X_i`` represents the number of elements of type ``i`` drawn; then the distribution of ``X`` is a multivariate hypergeometric distribution. Each sample of a multivariate hypergeometric distribution is a ``k``-dimensional integer vector that sums to ``n`` and satisfies ``0 \\le X_i \\le m_i``.


The probability mass function is given by

```math
f(x; m, n) = {{{m_1 \\choose x_1}{m_2 \\choose x_2}\\cdots {m_k \\choose x_k}}\\over {M \\choose n}},
\\quad x_1 + \\cdots + x_k = n, \\quad 0 \\le x_i \\le m_i
```

```julia
MvHypergeometric(m, n) # Multivariate hypergeometric distribution for a population with
# m[i] elements of type i and n draws
```
"""
struct MvHypergeometric <: DiscreteMultivariateDistribution
# Population composition: m[i] is the number of elements of type i.
m::Vector{Int} # number of elements of each type
# Sample size: how many elements are drawn from the population.
n::Int # number of draws
# Inner constructor. The checks below require every entry of `m` to be
# non-negative and `0 <= n <= sum(m)`; `check_args` is presumably consumed by
# the `@check_args` macro to allow skipping them (confirm against the macro).
# NOTE(review): `m` is stored as passed (no copy), so later mutation of the
# caller's vector is visible through the distribution.
function MvHypergeometric(m::Vector{Int}, n::Int; check_args::Bool=true)
@check_args(
MvHypergeometric,
(m, all(x -> x >= 0, m)),
zero(n) <= n <= sum(m),
)
new(m, n)
end
end


# Parameters

# Number of element types (categories) in the population.
function ncategories(d::MvHypergeometric)
    return length(d.m)
end

# A sample holds one count per category, so its length is the category count.
function length(d::MvHypergeometric)
    return ncategories(d)
end

# Number of draws taken from the population.
function ntrials(d::MvHypergeometric)
    return d.n
end

# Parameters as the tuple `(m, n)`.
function params(d::MvHypergeometric)
    return (d.m, d.n)
end

# Both parameters are stored as `Int`.
function partype(::MvHypergeometric)
    return Int
end

# Statistics

# Mean vector: E[X_i] = n * m_i / M, where M = sum(m) is the population size.
#
# Returns a `Vector{Float64}` with one entry per category.
function mean(d::MvHypergeometric)
    M = sum(d.m)
    # An empty population (every m[i] == 0) forces n == 0, so each count is
    # exactly zero; return zeros instead of evaluating 0 / 0 = NaN.
    iszero(M) && return zeros(Float64, length(d.m))
    return d.n .* d.m ./ M
end

# Marginal variances: Var[X_i] = n * (M - n) / (M - 1) * (m_i / M) * (1 - m_i / M),
# where M = sum(m) is the population size.
#
# Returns a `Vector{Float64}` with one entry per category.
function var(d::MvHypergeometric)
    m = d.m
    n = ntrials(d)
    M = sum(m)
    # With at most one element in the population the outcome is deterministic
    # (n is 0 or M), so the variance is exactly zero. Without this guard the
    # factor below evaluates 0 / 0 and every entry becomes NaN.
    M > 1 || return zeros(Float64, length(m))

    # Finite-population correction shared by all categories.
    f = n * (M - n) / (M - 1)
    return map(mi -> f * (mi / M) * ((M - mi) / M), m)
end

# Covariance matrix. With p = m / M and f = n * (M - n) / (M - 1):
#   Cov[X_i, X_j] = -f * p_i * p_j        for i != j
#   Var[X_i]      =  f * p_i * (1 - p_i)
#
# Returns a `Matrix{Float64}` of size k x k, where k is the number of categories.
function cov(d::MvHypergeometric)
    m = d.m
    k = length(m)
    n = ntrials(d)
    M = sum(m)
    # With at most one element in the population the outcome is deterministic,
    # so the covariance is exactly zero. Without this guard the factor `f`
    # evaluates 0 / 0 and the whole matrix becomes NaN.
    M > 1 || return zeros(Float64, k, k)

    p = m / M
    f = n * (M - n) / (M - 1)

    # Fill the off-diagonal form everywhere, then overwrite the diagonal.
    C = -f * (p * p')
    C[diagind(C)] .= f .* p .* (1 .- p)

    return C
end


# Evaluation
# Return `true` when `x` is a possible outcome of `d`: it has one entry per
# category, all entries are integer-valued, each satisfies 0 <= x[i] <= m[i],
# and the entries sum to the number of draws.
function insupport(d::MvHypergeometric, x::AbstractVector{<:Real})
    length(x) == length(d) || return false
    # Integer element types need no per-element check.
    (eltype(x) <: Integer || all(isinteger, x)) || return false
    for (xi, mi) in zip(x, d.m)
        (zero(xi) <= xi <= mi) || return false
    end
    return sum(x) == ntrials(d)
end

# Log probability mass of `x`:
#   log f(x) = sum_i log binomial(m_i, x_i) - log binomial(M, n),   M = sum(m).
#
# Support is checked while accumulating, so any `x` outside the support
# (wrong length, non-integer entry, entry outside 0:m[i], or entries not summing
# to n) yields `-Inf`. Integer-valued entries of non-`Integer` element type
# (e.g. `2.0`) are converted to `Int` before being passed to `logabsbinomial`.
function _logpdf(d::MvHypergeometric, x::AbstractVector{<:Real})
    m = d.m
    length(x) == length(m) || return -Inf

    # Draws not yet accounted for by the entries seen so far; kept as `Int`
    # so the accumulator never changes type.
    remaining = ntrials(d)
    s = -first(logabsbinomial(sum(m), remaining))
    for (xi, mi) in zip(x, m)
        (isinteger(xi) && zero(xi) <= xi <= mi) || return -Inf
        # Safe: xi is integer-valued and bounded by mi, which is an `Int`.
        ki = convert(Int, xi)
        s += first(logabsbinomial(mi, ki))
        remaining -= ki
    end
    # Entries must sum to exactly n.
    return iszero(remaining) ? s : -Inf
end

# Sampling is performed by sequentially sampling each entry from the
# hypergeometric distribution
function _rand!(rng::AbstractRNG, d::MvHypergeometric, x::AbstractVector{<:Real})
    counts = d.m
    ncat = length(d)
    length(x) == ncat || throw(DimensionMismatch("Invalid argument dimension."))

    draws_left = ntrials(d)   # draws not yet assigned to a category
    pop_left = sum(counts)    # population size of the categories not yet visited
    filled = 0                # index of the last entry written by the loop

    # Visit the first ncat - 1 categories in order, stopping early once every
    # draw has been assigned.
    for i in 1:(ncat - 1)
        draws_left > 0 || break
        ci = counts[i]
        # Elements of type i are the successes; all remaining types together
        # are the failures.
        xi = rand(rng, Hypergeometric(ci, pop_left - ci, draws_left))
        x[i] = xi
        # Drop category i from both the population and the pending draws.
        draws_left -= xi
        pop_left -= ci
        filled = i
    end

    if filled == ncat - 1
        # Whatever is left over belongs to the last category.
        x[ncat] = draws_left
    else
        # The loop stopped early because no draws remain: zero the tail.
        fill!(view(x, (filled + 1):ncat), zero(eltype(x)))
    end

    return x
end



3 changes: 2 additions & 1 deletion src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ for fname in ["dirichlet.jl",
"mvlognormal.jl",
"mvtdist.jl",
"product.jl", # deprecated
"vonmisesfisher.jl"]
"vonmisesfisher.jl",
"mvhypergeometric.jl"]
include(joinpath("multivariate", fname))
end
135 changes: 135 additions & 0 deletions test/multivariate/mvhypergeometric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Tests for Multivariate Hypergeometric

using Distributions
using Test

# End-to-end checks for MvHypergeometric: constructor validation, parameter
# accessors, moments, support, sampling, (log)pdf values, degenerate cases, and
# consistency with the univariate Hypergeometric distribution.
# Note: variables such as `m`, `n`, `d` and `x` are rebound several times below,
# so the statements depend on their order.
@testset "Multivariate Hypergeometric" begin
# constructor validation: a negative count raises DomainError, and requesting
# more draws than the population holds raises ArgumentError
@test_throws DomainError MvHypergeometric([5, 3, -2], 4)
@test_throws ArgumentError MvHypergeometric([5, 3], 10)

# parameter accessors on a population of size 10 with 4 draws
m = [5, 3, 2]
n = 4
d = MvHypergeometric(m, n)
@test length(d) == 3
@test d.n == n
@test d.m == m
@test ncategories(d) == length(m)
@test params(d) == (m, n)
@test partype(d) == Int

# moments: mean is n * m / 10; variances use the factor 4 * 6 / 9 = 8 / 3
@test mean(d) ≈ [2.0, 1.2, 0.8]
@test var(d) ≈ [2 / 3, 56 / 100, 32 / 75]

covmat = cov(d)
@test covmat ≈ (8 / 3) .* [1/4 -3/20 -1/10; -3/20 21/100 -3/50; -1/10 -3/50 4/25]

# support: [3, 2, 1] sums to 6 rather than 4, and [0, 0, 4] exceeds m[3] = 2
@test insupport(d, [2, 1, 1])
@test !insupport(d, [3, 2, 1])
@test !insupport(d, [0, 0, 4])


# random sampling: a single draw is an Int vector inside the support
x = rand(d)
@test isa(x, Vector{Int})
@test sum(x) == n
@test all(x .>= 0)
@test all(x .<= m)
@test insupport(d, x)

# random sampling: 100 draws are returned as the columns of an Int matrix
x = rand(d, 100)
@test isa(x, Matrix{Int})
@test all(sum(x, dims=1) .== n)
@test all(x .>= 0)
@test all(x .<= m)
@test all(insupport(d, x))

# random sampling with many categories
m = [20, 2, 2, 2, 1, 1, 1]
n = 5
d2 = MvHypergeometric(m, n)
x = rand(d2)
@test isa(x, Vector{Int})
@test sum(x) == n
@test all(x .>= 0)
@test all(x .<= m)
@test insupport(d2, x)

# random sampling with a large category
m = [2, 1000]
n = 5
d3 = MvHypergeometric(m, n)
x = rand(d3)
@test isa(x, Vector{Int})
@test sum(x) == n
@test all(x .>= 0)
@test all(x .<= m)
@test insupport(d3, x)

# log pdf (`d` is still the [5, 3, 2], n = 4 distribution):
# binomial(5, 2) * binomial(3, 1) * binomial(2, 1) / binomial(10, 4) = 60 / 210 = 2 / 7
x = [2, 1, 1]
@test pdf(d, x) ≈ 2 / 7
@test logpdf(d, x) ≈ log(2 / 7)
@test logpdf(d, x) ≈ log(pdf(d, x))
# non-integer entries are outside the support
@test logpdf(d, [2.5, 0.5, 1]) == -Inf

# matrix input: column-wise pdf/logpdf must agree with the per-vector values
x = rand(d, 100)
pv = pdf(d, x)
lp = logpdf(d, x)
for i in 1:size(x, 2)
@test pv[i] ≈ pdf(d, x[:, i])
@test lp[i] ≈ logpdf(d, x[:, i])
end

# test degenerate cases: single category, and categories with zero elements
d1 = MvHypergeometric([1], 1)
@test logpdf(d1, [1]) ≈ 0
@test logpdf(d1, [0]) == -Inf
d2 = MvHypergeometric([2, 0], 1)
@test logpdf(d2, [1, 0]) ≈ 0
@test logpdf(d2, [0, 1]) == -Inf

d3 = MvHypergeometric([5, 0, 0, 0], 3)
@test logpdf(d3, [3, 0, 0, 0]) ≈ 0
@test logpdf(d3, [2, 1, 0, 0]) == -Inf
@test logpdf(d3, [2, 0, 0, 0]) == -Inf

# behavior with n = 0: the only outcome is the zero vector
d0 = MvHypergeometric([5, 3, 2], 0)
@test logpdf(d0, [0, 0, 0]) ≈ 0
@test logpdf(d0, [1, 0, 0]) == -Inf

@test rand(d0) == [0, 0, 0]
@test mean(d0) == [0.0, 0.0, 0.0]
@test var(d0) == [0.0, 0.0, 0.0]
@test insupport(d0, [0, 0, 0])
@test !insupport(d0, [1, 0, 0])
@test length(d0) == 3

# compare with hypergeometric: with two categories the pmf must match the
# univariate Hypergeometric pmf of the first count
ns = 3
nf = 5
n = 4
dh1 = MvHypergeometric([ns, nf], n)
dh2 = Hypergeometric(ns, nf, n)

x = 2
@test pdf(dh1, [x, n - x]) ≈ pdf(dh2, x)
x = 3
@test pdf(dh1, [x, n - x]) ≈ pdf(dh2, x)

# comparing marginals to hypergeometric: summing the joint pmf over x2 gives
# the univariate pmf of x1 with all other types pooled as failures
m = [5, 3, 2]
n = 4
d = MvHypergeometric(m, n)
dh = Hypergeometric(m[1], sum(m[2:end]), n)
x1 = 2
@test pdf(dh, x1) ≈ sum([pdf(d, [x1, x2, n - x1 - x2]) for x2 in 0:m[2]])

# comparing conditionals to hypergeometric: given x1, the count x2 follows a
# Hypergeometric over the remaining two types with n - x1 draws
x1 = 2
dh = Hypergeometric(m[2], m[3], n - x1)
q = sum([pdf(d, [x1, x2, n - x1 - x2]) for x2 in 0:m[2]])
for x2 = 0:m[2]
@test pdf(dh, x2) ≈ pdf(d, [x1, x2, n - x1 - x2]) / q
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ const tests = [
"univariate/continuous/triangular",
"statsapi",
"univariate/continuous/inversegaussian",
"multivariate/mvhypergeometric",

### missing files compared to /src:
# "common",
Expand Down
Loading