-
Notifications
You must be signed in to change notification settings - Fork 432
Add multivariate hypergeometric distribution #1963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Michael-Howes
wants to merge
43
commits into
JuliaStats:master
Choose a base branch
from
Michael-Howes:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
dd6376e
create mvhypergeomtric.jl
Michael-Howes 8bcfaba
add MvHypergeometric to module
Michael-Howes 7f494af
add multivariate hypergeometric sampler
Michael-Howes 47e416d
add tests
Michael-Howes 585caff
adding tests
Michael-Howes 43e824a
add tests comparing multivariate hypergeometric to univariate hyperge…
Michael-Howes 53b005f
add comments to mvhypergeometric sampler
Michael-Howes f2d6b4f
rename test file for multivariate hypergeometric
Michael-Howes 4bdaca4
append mvhypergeometric tests to runtest.jl
Michael-Howes d17c249
update tests
Michael-Howes 9e37f95
update project.toml
Michael-Howes 91d9da5
update tests to improve coverage
Michael-Howes 74240a6
Merge branch 'JuliaStats:master' into master
Michael-Howes 88878b4
rename tests
Michael-Howes e6b67d7
Update file names
Michael-Howes f7b61ff
Update documentation
Michael-Howes 58f992d
Improve test coverage
Michael-Howes 65c83f7
Update sampling function to remove redundant logic.
Michael-Howes dbeeb04
formate with julia vscode extension
Michael-Howes 30c7882
reformat
Michael-Howes fc068db
Merge branch 'master' into master
Michael-Howes 3f5985c
Remove nelements functions
Michael-Howes a5f8cc0
Remove nelements from test
Michael-Howes 17ee592
remove use of nelements
Michael-Howes a01a3e1
Add explicitly defined partype
Michael-Howes 4d2ed35
Remove abstractly typed fields
Michael-Howes c0a6cc0
Remove AbstractVector{Int{
Michael-Howes 5fb20ff
Remove where clause in type signature
Michael-Howes 4d58d63
Remove where clause in type signature
Michael-Howes b0c2386
Optimize @check_args
Michael-Howes 53f875a
Improve var computation
Michael-Howes 7042107
Change cov computation
Michael-Howes 947fed3
Optimize insupport
Michael-Howes d096ba5
Remove inbounds from _logpdf
Michael-Howes 20d68b8
Remove inbounds from sampling
Michael-Howes 4853cbc
Remove inbounds from sampling
Michael-Howes 84dfbb6
Move sampling method into main file
Michael-Howes cccd79f
Update formatting
Michael-Howes 6c20f7e
Add partype test
Michael-Howes b3496ab
Check support during compution of _logpdf
Michael-Howes 26e33f7
Update docstring
Michael-Howes afe12d1
Update support check in _logpdf
Michael-Howes 6b0ad57
Merge branch 'JuliaStats:master' into master
Michael-Howes File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,7 @@ MvNormalCanon | |
| MvLogitNormal | ||
| MvLogNormal | ||
| Dirichlet | ||
| MvHypergeometric | ||
| ``` | ||
|
|
||
| ## Addition Methods | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| """ | ||
| MvHypergeometric(m, n) | ||
|
|
||
| The [Multivariate hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution#Multivariate_hypergeometric_distribution) | ||
| generalizes the *hypergeometric distribution*. Consider ``n`` draws from a finite population containing ``k`` types of elements. Suppose that the population has size ``M`` and there are ``m_i`` elements of type ``i`` for ``i = 1, .., k`` with ``m_1+...m_k = M``. Let ``X = (X_1, ..., X_k)`` where ``X_i`` represents the number of elements of type ``i`` drawn, then the distribution of ``X`` is a multivariate hypergeometric distribution. Each sample of a multivariate hypergeometric distribution is a ``k``-dimensional integer vector that sums to ``n`` and satisfies ``0 \\le X_i \\le m_i``. | ||
|
|
||
|
|
||
| The probability mass function is given by | ||
|
|
||
| ```math | ||
| f(x; m, n) = {{{m_1 \\choose x_1}{m_2 \\choose x_2}\\cdots {m_k \\choose x_k}}\\over {N \\choose n}}, | ||
| \\quad x_1 + \\cdots + x_k = n, \\quad 0 \\le x_i \\le m_i | ||
| ``` | ||
|
|
||
| ```julia | ||
| MvHypergeometric(m, n) # Multivariate hypergeometric distribution for a population with | ||
| # m[i] elements of types i and n draws | ||
| ``` | ||
| """ | ||
| struct MvHypergeometric <: DiscreteMultivariateDistribution | ||
| m::Vector{Int} # number of elements of each type | ||
| n::Int # number of draws | ||
| function MvHypergeometric(m::Vector{Int}, n::Int; check_args::Bool=true) | ||
| @check_args( | ||
| MvHypergeometric, | ||
| (m, all(x -> x >= 0, m)), | ||
| zero(n) <= n <= sum(m), | ||
| ) | ||
| new(m, n) | ||
| end | ||
| end | ||
|
|
||
|
|
||
| # Parameters | ||
|
|
||
| ncategories(d::MvHypergeometric) = length(d.m) | ||
| length(d::MvHypergeometric) = ncategories(d) | ||
| ntrials(d::MvHypergeometric) = d.n | ||
|
|
||
| params(d::MvHypergeometric) = (d.m, d.n) | ||
Michael-Howes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| partype(::MvHypergeometric) = Int | ||
|
|
||
| # Statistics | ||
|
|
||
| mean(d::MvHypergeometric) = d.n .* d.m ./ sum(d.m) | ||
|
|
||
| function var(d::MvHypergeometric) | ||
| m = d.m | ||
| n = ntrials(d) | ||
| M = sum(m) | ||
| f = n * (M - n) / (M - 1) | ||
| v = let f = f | ||
| map(mi -> f * (mi / M) * ((M - mi) / M), m) | ||
| end | ||
| v | ||
| end | ||
|
|
||
| function cov(d::MvHypergeometric) | ||
| m = d.m | ||
| n = ntrials(d) | ||
| M = sum(m) | ||
| p = m / M | ||
| f = n * (M - n) / (M - 1) | ||
|
|
||
| C = -f * (p * p') | ||
| C[diagind(C)] .= f .* p .* (1 .- p) | ||
|
|
||
| C | ||
| end | ||
|
|
||
|
|
||
| # Evaluation | ||
| function insupport(d::MvHypergeometric, x::AbstractVector{<:Real}) | ||
| return length(x) == length(d) && (eltype(x) <: Integer || all(isinteger, x)) && all(((xi, mi),) -> zero(xi) <= xi <= mi, zip(x, d.m)) && sum(x) == ntrials(d) | ||
| end | ||
|
|
||
| function _logpdf(d::MvHypergeometric, x::AbstractVector{<:Real}) | ||
| m = d.m | ||
| M = sum(m) | ||
| n = ntrials(d) | ||
|
|
||
| s = -logabsbinomial(M, n)[1] | ||
| for i = 1:length(m) | ||
| xi = x[i] | ||
| mi = m[i] | ||
| ((typeof(xi) <: Integer || isinteger(xi)) && (zero(xi) <= xi <= mi)) || return -Float64(Inf) | ||
| s += logabsbinomial(mi, xi)[1] | ||
| n -= xi | ||
| end | ||
| (n == 0) || return -Float64(Inf) | ||
| return s | ||
| end | ||
|
|
||
| # Sampling is performed by sequentially sampling each entry from the | ||
| # hypergeometric distribution | ||
| function _rand!(rng::AbstractRNG, d::MvHypergeometric, x::AbstractVector{<:Real}) | ||
| k = length(d) | ||
| n = ntrials(d) | ||
| m = d.m | ||
| length(x) == k || throw(DimensionMismatch("Invalid argument dimension.")) | ||
|
|
||
| M = sum(m) | ||
| i = 0 | ||
| km1 = k - 1 | ||
|
|
||
| while i < km1 && n > 0 | ||
| i += 1 | ||
| mi = m[i] | ||
| # Sample from hypergeometric distribution. Element of type i are | ||
| # considered successes. All other elements are considered failures. | ||
| xi = rand(rng, Hypergeometric(mi, M - mi, n)) | ||
| x[i] = xi | ||
| # Remove elements of type i from population and group to be sampled. | ||
| n -= xi | ||
| M -= mi | ||
| end | ||
|
|
||
| if i == km1 | ||
| x[k] = n | ||
| else # n must have been zero. | ||
| z = zero(eltype(x)) | ||
| for j = i+1:k | ||
| x[j] = z | ||
| end | ||
| end | ||
|
|
||
| return x | ||
| end | ||
|
|
||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| # Tests for Multivariate Hypergeometric | ||
|
|
||
| using Distributions | ||
| using Test | ||
|
|
||
| @testset "Multivariate Hypergeometric" begin | ||
| @test_throws DomainError MvHypergeometric([5, 3, -2], 4) | ||
| @test_throws ArgumentError MvHypergeometric([5, 3], 10) | ||
|
|
||
| m = [5, 3, 2] | ||
| n = 4 | ||
| d = MvHypergeometric(m, n) | ||
| @test length(d) == 3 | ||
| @test d.n == n | ||
| @test d.m == m | ||
| @test ncategories(d) == length(m) | ||
| @test params(d) == (m, n) | ||
| @test partype(d) == Int | ||
|
|
||
| @test mean(d) ≈ [2.0, 1.2, 0.8] | ||
| @test var(d) ≈ [2 / 3, 56 / 100, 32 / 75] | ||
|
|
||
| covmat = cov(d) | ||
| @test covmat ≈ (8 / 3) .* [1/4 -3/20 -1/10; -3/20 21/100 -3/50; -1/10 -3/50 4/25] | ||
|
|
||
| @test insupport(d, [2, 1, 1]) | ||
| @test !insupport(d, [3, 2, 1]) | ||
| @test !insupport(d, [0, 0, 4]) | ||
|
|
||
|
|
||
| # random sampling | ||
| x = rand(d) | ||
| @test isa(x, Vector{Int}) | ||
| @test sum(x) == n | ||
| @test all(x .>= 0) | ||
| @test all(x .<= m) | ||
| @test insupport(d, x) | ||
|
|
||
| x = rand(d, 100) | ||
| @test isa(x, Matrix{Int}) | ||
| @test all(sum(x, dims=1) .== n) | ||
| @test all(x .>= 0) | ||
| @test all(x .<= m) | ||
| @test all(insupport(d, x)) | ||
|
|
||
| # random sampling with many catergories | ||
| m = [20, 2, 2, 2, 1, 1, 1] | ||
| n = 5 | ||
| d2 = MvHypergeometric(m, n) | ||
| x = rand(d2) | ||
| @test isa(x, Vector{Int}) | ||
| @test sum(x) == n | ||
| @test all(x .>= 0) | ||
| @test all(x .<= m) | ||
| @test insupport(d2, x) | ||
|
|
||
| # random sampling with a large category | ||
| m = [2, 1000] | ||
| n = 5 | ||
| d3 = MvHypergeometric(m, n) | ||
| x = rand(d3) | ||
| @test isa(x, Vector{Int}) | ||
| @test sum(x) == n | ||
| @test all(x .>= 0) | ||
| @test all(x .<= m) | ||
| @test insupport(d3, x) | ||
|
|
||
| # log pdf | ||
| x = [2, 1, 1] | ||
| @test pdf(d, x) ≈ 2 / 7 | ||
| @test logpdf(d, x) ≈ log(2 / 7) | ||
| @test logpdf(d, x) ≈ log(pdf(d, x)) | ||
| @test logpdf(d, [2.5, 0.5, 1]) == -Inf | ||
|
|
||
| x = rand(d, 100) | ||
| pv = pdf(d, x) | ||
| lp = logpdf(d, x) | ||
| for i in 1:size(x, 2) | ||
| @test pv[i] ≈ pdf(d, x[:, i]) | ||
| @test lp[i] ≈ logpdf(d, x[:, i]) | ||
| end | ||
|
|
||
| # test degenerate cases | ||
| d1 = MvHypergeometric([1], 1) | ||
| @test logpdf(d1, [1]) ≈ 0 | ||
| @test logpdf(d1, [0]) == -Inf | ||
| d2 = MvHypergeometric([2, 0], 1) | ||
| @test logpdf(d2, [1, 0]) ≈ 0 | ||
| @test logpdf(d2, [0, 1]) == -Inf | ||
|
|
||
| d3 = MvHypergeometric([5, 0, 0, 0], 3) | ||
| @test logpdf(d3, [3, 0, 0, 0]) ≈ 0 | ||
| @test logpdf(d3, [2, 1, 0, 0]) == -Inf | ||
| @test logpdf(d3, [2, 0, 0, 0]) == -Inf | ||
|
|
||
| # behavior with n = 0 | ||
| d0 = MvHypergeometric([5, 3, 2], 0) | ||
| @test logpdf(d0, [0, 0, 0]) ≈ 0 | ||
| @test logpdf(d0, [1, 0, 0]) == -Inf | ||
|
|
||
| @test rand(d0) == [0, 0, 0] | ||
| @test mean(d0) == [0.0, 0.0, 0.0] | ||
| @test var(d0) == [0.0, 0.0, 0.0] | ||
| @test insupport(d0, [0, 0, 0]) | ||
| @test !insupport(d0, [1, 0, 0]) | ||
| @test length(d0) == 3 | ||
|
|
||
| # compare with hypergeometric | ||
| ns = 3 | ||
| nf = 5 | ||
| n = 4 | ||
| dh1 = MvHypergeometric([ns, nf], n) | ||
| dh2 = Hypergeometric(ns, nf, n) | ||
|
|
||
| x = 2 | ||
| @test pdf(dh1, [x, n - x]) ≈ pdf(dh2, x) | ||
| x = 3 | ||
| @test pdf(dh1, [x, n - x]) ≈ pdf(dh2, x) | ||
|
|
||
| # comparing marginals to hypergeometric | ||
| m = [5, 3, 2] | ||
| n = 4 | ||
| d = MvHypergeometric(m, n) | ||
| dh = Hypergeometric(m[1], sum(m[2:end]), n) | ||
| x1 = 2 | ||
| @test pdf(dh, x1) ≈ sum([pdf(d, [x1, x2, n - x1 - x2]) for x2 in 0:m[2]]) | ||
|
|
||
| # comparing conditionals to hypergeometric | ||
| x1 = 2 | ||
| dh = Hypergeometric(m[2], m[3], n - x1) | ||
| q = sum([pdf(d, [x1, x2, n - x1 - x2]) for x2 in 0:m[2]]) | ||
| for x2 = 0:m[2] | ||
| @test pdf(dh, x2) ≈ pdf(d, [x1, x2, n - x1 - x2]) / q | ||
| end | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Julia docstrings should contain the type or function signature in the first line, indented by four spaces.
Generally, could you make this docstring consistent with docstrings of other existing distribution types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added the function signature. The doc string is based on the multinomial distribution docstring. Is there a different distribution I should be following?
https://github.com/JuliaStats/Distributions.jl/blob/0421b1890c49417de993b255f6b177a1cb93fca3/src/multivariate/multinomial.jl