From 8544b0ff9754dd16eecd1562d55f1b452408cccd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Feb 2025 16:11:54 +0000 Subject: [PATCH 1/3] Change dot_tilde to use filldist rather than a loop --- HISTORY.md | 8 ++++---- Project.toml | 1 + src/DynamicPPL.jl | 1 + src/compiler.jl | 7 ++----- src/pointwise_logdensities.jl | 9 ++++----- src/test_utils/models.jl | 12 ++++++------ test/varinfo.jl | 9 ++------- 7 files changed, 20 insertions(+), 27 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 748b9d506..cd539adac 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -45,7 +45,9 @@ x ~ product_distribution(Normal.(y)) x ~ MvNormal(fill(0.0, 2), I) ``` -This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as +This is often more performant as well. + +The new implementation of `x .~ dist` is just shorthand for `x ~ product_distribution(fill(dist, size(x)...))`, which means that `x` will be seen as a single multivariate variable. In most cases this does not change anything for the user, with the one notable exception being `pointwise_loglikelihoods`, which previously treated `.~` assignments as assigning multiple univariate variables. If you _do_ want a variable to be seen as an array of univariate variables rather than a single multivariate variable, you can always expand into a loop, such as ```julia dists = Normal.(y) @@ -54,7 +56,7 @@ for i in 1:length(dists) end ``` -Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. 
For example, +Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must always be replaced with a loop. For example, ```julia x = Array{Float64,3}(undef, 2, 3, 4) @@ -70,8 +72,6 @@ for i in 1:3, j in 1:4 end ``` -This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side. - ### Remove indexing by samplers This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, diff --git a/Project.toml b/Project.toml index a9463a821..07fdf3a8e 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 50fe0edc7..f1542e0a9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -5,6 +5,7 @@ using AbstractPPL using Bijectors using Compat using Distributions +using DistributionsAD: filldist using OrderedCollections: OrderedCollections, OrderedDict using AbstractMCMC: AbstractMCMC diff --git a/src/compiler.jl b/src/compiler.jl index 8bde5e784..27592aecd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -514,13 +514,10 @@ end Generate the expression that replaces `left .~ right` in the model body. 
""" function generate_dot_tilde(left, right) - @gensym dist left_axes idx + @gensym dist return quote $dist = DynamicPPL.check_dot_tilde_rhs($right) - $left_axes = axes($left) - for $idx in Iterators.product($left_axes...) - $left[$idx...] ~ $dist - end + $left ~ DynamicPPL.filldist($dist, Base.size($left)...) end end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..9ed229c8e 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -157,8 +157,8 @@ y .~ Normal(μ, σ) y ~ MvNormal(fill(μ, n), σ^2 * I) ``` -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. +In (1) `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +while in (2) and (3) `y` will be treated as a _single_ n-dimensional observation. This is important to keep in mind, in particular if the computation is used for downstream computations. @@ -216,8 +216,7 @@ OrderedDict{VarName, Matrix{Float64}} with 6 entries: ``` ## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. +Note that `x .~ Dist()` will treat `x` as a single multivariate observation. 
```jldoctest; setup = :(using Distributions) julia> @model function demo(x) @@ -226,7 +225,7 @@ julia> @model function demo(x) julia> m = demo([1.0, ]); -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x)]) -1.4189385332046727 julia> m = demo([1.0; 1.0]); diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index e29614982..14b264c98 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -208,7 +208,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end @model function demo_assume_index_observe( @@ -293,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -374,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end @model function demo_assume_observe_literal() @@ -458,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end @model function _likelihood_multivariate_observe(s, m, x) @@ -492,7 +492,7 @@ function 
logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end @model function demo_dot_assume_observe_matrix_index( @@ -521,7 +521,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m)] + return [@varname(s), @varname(m)] end @model function demo_assume_matrix_observe_matrix_index( diff --git a/test/varinfo.jl b/test/varinfo.jl index d689a1bf4..74012ecb1 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -426,12 +426,7 @@ end # Transform only one variable all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) for vn in [ - @varname(s), - @varname(m), - @varname(x), - @varname(y), - @varname(x[2]), - @varname(y[2]) + @varname(s), @varname(m), @varname(x), @varname(y), @varname(x), @varname(y[2]) ] target_vns = filter(x -> subsumes(vn, x), all_vns) other_vns = filter(x -> !subsumes(vn, x), all_vns) @@ -874,7 +869,7 @@ end varinfo2 = last( DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) ) - for vn in [@varname(x), @varname(y[1])] + for vn in [@varname(x), @varname(y)] @test DynamicPPL.istrans(varinfo2, vn) end end From 15a0a88f3800d0c9f511cce8badda246b420796b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Feb 2025 16:55:30 +0000 Subject: [PATCH 2/3] Fix pointwise_logdensities doctest --- src/pointwise_logdensities.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 9ed229c8e..83f0bb739 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -223,15 +223,10 @@ julia> @model function demo(x) x .~ Normal() 
end; -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x)]) --1.4189385332046727 - julia> m = demo([1.0; 1.0]); -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x)]) +-2.8378770664093453 ``` """ From 67d3629597ddf922633f084454dd220c7a7a0886 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Feb 2025 11:15:02 +0000 Subject: [PATCH 3/3] Switch from filldist to product_distribution --- Project.toml | 1 - src/DynamicPPL.jl | 1 - src/compiler.jl | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 07fdf3a8e..a9463a821 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f1542e0a9..50fe0edc7 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -5,7 +5,6 @@ using AbstractPPL using Bijectors using Compat using Distributions -using DistributionsAD: filldist using OrderedCollections: OrderedCollections, OrderedDict using AbstractMCMC: AbstractMCMC diff --git a/src/compiler.jl b/src/compiler.jl index 27592aecd..7bdefdc8f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -517,7 +517,7 @@ function generate_dot_tilde(left, right) @gensym dist return quote $dist = DynamicPPL.check_dot_tilde_rhs($right) - $left ~ DynamicPPL.filldist($dist, Base.size($left)...) 
+ $left ~ DynamicPPL.product_distribution(Base.fill($dist, Base.size($left)...)) end end