From 5dfcaa658a3161995931cc680a4a48a96cffd28d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Mar 2025 00:46:13 +0000 Subject: [PATCH 1/5] Implement `tracked_varnames` --- src/values_as_in_model.jl | 76 +++++++++++++++++++++++++++---- test/compiler.jl | 12 ----- test/model.jl | 42 ----------------- test/runtests.jl | 1 + test/values_as_in_model.jl | 93 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 161 insertions(+), 63 deletions(-) create mode 100644 test/values_as_in_model.jl diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..4c1db34c2 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -1,3 +1,12 @@ +""" + TrackedValue{T} + +A struct that wraps something on the right-hand side of `:=`. This is needed +because the DynamicPPL compiler actually converts `lhs := rhs` to `lhs ~ +TrackedValue(rhs)` (so that we can hit the `tilde_assume` method below). Having +the rhs wrapped in a TrackedValue makes sure that the logpdf of the rhs is not +computed (as it wouldn't make sense). 
+""" struct TrackedValue{T} value::T end @@ -24,17 +33,27 @@ struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext values::OrderedDict "whether to extract variables on the LHS of :=" include_colon_eq::Bool + "varnames to be tracked; `nothing` means track all varnames" + tracked_varnames::Union{Nothing,Array{<:VarName}} "child context" context::C end -function ValuesAsInModelContext(include_colon_eq, context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context) +function ValuesAsInModelContext( + include_colon_eq::Bool, + tracked_varnames::Union{Nothing,Array{<:VarName}}, + context::AbstractContext, +) + return ValuesAsInModelContext( + OrderedDict(), include_colon_eq, tracked_varnames, context + ) end NodeTrait(::ValuesAsInModelContext) = IsParent() childcontext(context::ValuesAsInModelContext) = context.context function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, context.include_colon_eq, child) + return ValuesAsInModelContext( + context.values, context.include_colon_eq, context.tracked_varnames, child + ) end is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq @@ -63,29 +82,38 @@ end # `tilde_asssume` function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - if is_tracked_value(right) + is_tracked_value_right = is_tracked_value(right) + if is_tracked_value_right value = right.value logp = zero(getlogp(vi)) else value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) end # Save the value. - push!(context, vn, value) - # Save the value. + if is_tracked_value_right || + isnothing(context.tracked_varnames) || + any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames) + push!(context, vn, value) + end # Pass on. 
return value, logp, vi end function tilde_assume( rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) - if is_tracked_value(right) + is_tracked_value_right = is_tracked_value(right) + if is_tracked_value_right value = right.value logp = zero(getlogp(vi)) else value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) end # Save the value. - push!(context, vn, value) + if is_tracked_value_right || + isnothing(context.tracked_varnames) || + any(tracked_vn -> subsumes(tracked_vn, vn), context.tracked_varnames) + push!(context, vn, value) + end # Pass on. return value, logp, vi end @@ -167,9 +195,39 @@ function values_as_in_model( model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo, + tracked_varnames=tracked_varnames(model), context::AbstractContext=DefaultContext(), ) - context = ValuesAsInModelContext(include_colon_eq, context) + tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames) + context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context) evaluate!!(model, varinfo, context) return context.values end + +""" + tracked_varnames(model::Model) + +Returns a set of `VarName`s that the model should track. + +By default, this returns `nothing`, which means that all `VarName`s should be +tracked. + +If you want to track only a subset of `VarName`s, you can override this method +in your model definition: + +```julia +@model function mymodel() + x ~ Normal() + y ~ Normal(x, 1) +end + +DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)] +``` + +Then, when you sample from `mymodel()`, only the value of `y` will be tracked +(and not `x`). + +Note that quantities on the left-hand side of `:=` are always tracked, and will +ignore the varnames specified in this method. 
+""" +tracked_varnames(::Model) = nothing diff --git a/test/compiler.jl b/test/compiler.jl index 8d81c530a..2a0521fc6 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -728,18 +728,6 @@ module Issue537 end varinfo = VarInfo(model) @test haskey(varinfo, @varname(x)) @test !haskey(varinfo, @varname(y)) - - # While `values_as_in_model` should contain both `x` and `y`, if - # include_colon_eq is set to `true`. - values = values_as_in_model(model, true, deepcopy(varinfo)) - @test haskey(values, @varname(x)) - @test haskey(values, @varname(y)) - - # And if include_colon_eq is set to `false`, then `values` should - # only contain `x`. - values = values_as_in_model(model, false, deepcopy(varinfo)) - @test haskey(values, @varname(x)) - @test !haskey(values, @varname(y)) end end diff --git a/test/model.jl b/test/model.jl index 256ada0ad..37196c232 100644 --- a/test/model.jl +++ b/test/model.jl @@ -410,48 +410,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - @testset "values_as_in_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # We can set the include_colon_eq arg to false because none of - # the demo models contain :=. The behaviour when - # include_colon_eq is true is tested in test/compiler.jl - realizations = values_as_in_model(model, false, varinfo) - # Ensure that all variables are found. - vns_found = collect(keys(realizations)) - @test vns ∩ vns_found == vns ∪ vns_found - # Ensure that the values are the same. 
- for vn in vns - @test realizations[vn] == varinfo[vn] - end - end - end - - @testset "Prefixing" begin - @model inner() = x ~ Normal() - - @model function outer_auto_prefix() - a ~ to_submodel(inner(), true) - b ~ to_submodel(inner(), true) - return nothing - end - @model function outer_manual_prefix() - a ~ to_submodel(prefix(inner(), :a), false) - b ~ to_submodel(prefix(inner(), :b), false) - return nothing - end - - for model in (outer_auto_prefix(), outer_manual_prefix()) - vi = VarInfo(model) - vns = Set(keys(values_as_in_model(model, false, vi))) - @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) - end - end - end - @testset "Erroneous model call" begin # Calling a model with the wrong arguments used to lead to infinite recursion, see # https://github.com/TuringLang/Turing.jl/issues/2182. This guards against it. diff --git a/test/runtests.jl b/test/runtests.jl index caddef5f9..a5fb3e031 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,7 @@ include("test_util.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") + include("values_as_in_model.jl") include("sampler.jl") include("independence.jl") include("distribution_wrappers.jl") diff --git a/test/values_as_in_model.jl b/test/values_as_in_model.jl new file mode 100644 index 000000000..d21f980ee --- /dev/null +++ b/test/values_as_in_model.jl @@ -0,0 +1,93 @@ +@testset "values_as_in_model" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # We can set the include_colon_eq arg to false because none of + # the demo models contain :=. 
The behaviour when + # include_colon_eq is true is tested in the "support for :=" testset below + realizations = values_as_in_model(model, false, varinfo) + # Ensure that all variables are found. + vns_found = collect(keys(realizations)) + @test vns ∩ vns_found == vns ∪ vns_found + # Ensure that the values are the same. + for vn in vns + @test realizations[vn] == varinfo[vn] + end + end + end + + @testset "support for :=" begin + @model function demo_tracked() + x ~ Normal() + y := 100 + x + return (; x, y) + end + @model function demo_tracked_submodel() + return vals ~ to_submodel(demo_tracked(), false) + end + + for model in [demo_tracked(), demo_tracked_submodel()] + values = values_as_in_model(model, true, VarInfo(model)) + @test haskey(values, @varname(x)) + @test haskey(values, @varname(y)) + + values = values_as_in_model(model, false, VarInfo(model)) + @test haskey(values, @varname(x)) + @test !haskey(values, @varname(y)) + end + end + + @testset "Prefixing" begin + @model inner() = x ~ Normal() + + @model function outer_auto_prefix() + a ~ to_submodel(inner(), true) + b ~ to_submodel(inner(), true) + return nothing + end + @model function outer_manual_prefix() + a ~ to_submodel(prefix(inner(), :a), false) + b ~ to_submodel(prefix(inner(), :b), false) + return nothing + end + + for model in (outer_auto_prefix(), outer_manual_prefix()) + vi = VarInfo(model) + vns = Set(keys(values_as_in_model(model, false, vi))) + @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + end + end + + @testset "Track only specific varnames" begin + @model function track_specific() + x = Vector{Float64}(undef, 2) + # Include a vector x to test for correct subsumption behaviour + for i in eachindex(x) + x[i] ~ Normal() + end + y ~ Normal(x[1], 1) + return z := sum(x) + end + + model = track_specific() + vi = VarInfo(model) + + # Specify varnames to be tracked directly as an argument to `values_as_in_model` + values = values_as_in_model(model, true, vi, [@varname(x)]) + # Since x subsumes both 
x[1] and x[2], they should be included + @test haskey(values, @varname(x[1])) + @test haskey(values, @varname(x[2])) + @test !haskey(values, @varname(y)) + @test haskey(values, @varname(z)) # := is always included + + # Specify instead using `tracked_varnames` method + DynamicPPL.tracked_varnames(::Model{typeof(track_specific)}) = [@varname(y)] + values = values_as_in_model(model, true, vi) + @test !haskey(values, @varname(x[1])) + @test !haskey(values, @varname(x[2])) + @test haskey(values, @varname(y)) + @test haskey(values, @varname(z)) # := is always included + end +end From 5f37e6a8612743fa665dab7a45d812677252a560 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Mar 2025 10:41:54 +0000 Subject: [PATCH 2/5] Use model field instead --- src/DynamicPPL.jl | 1 + src/model.jl | 51 ++++++++++++++++++++++++++++++++------ src/values_as_in_model.jl | 31 ++--------------------- test/contexts.jl | 2 +- test/values_as_in_model.jl | 6 ++--- 5 files changed, 50 insertions(+), 41 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 50fe0edc7..e80762294 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -113,6 +113,7 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, + set_tracked_varnames, fix, unfix, predict, diff --git a/src/model.jl b/src/model.jl index 0fb18f463..9db8c9f6a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -4,12 +4,20 @@ args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx=DefaultContext() + tracked_varnames::Union{Nothing,Array{<:VarName}} end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing arguments `missings`, and evaluation context of type `Ctx`. +`tracked_varnames` is an array of VarNames that should be tracked during sampling. 
During +model evaluation (with `DynamicPPL.evaluate!!`) all random variables are tracked; however, +at the end of each iteration of MCMC sampling, `DynamicPPL.values_as_in_model` is used to +extract the values of _only_ the tracked variables. This allows the user to control which +variables are ultimately stored in the chain. This field can be set using the +[`set_tracked_varnames`](@ref) function. + Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. `context` is by default `DefaultContext()`. @@ -23,14 +31,17 @@ different arguments. # Examples ```julia +julia> f(x) = x + 1 # Dummy function +f (generic function with 1 method) + julia> Model(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple()) +Model{typeof(f), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}(f, (x = 1.0, y = 2.0), NamedTuple(), DefaultContext(), nothing) julia> Model(f, (x = 1.0, y = 2.0), (x = 42,)) -Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) +Model{typeof(f), (:x, :y), (:x,), (), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing) julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings -Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) +Model{typeof(f), (:x, :y), (:x,), (:y,), Tuple{Float64, Float64}, Tuple{Int64}, DefaultContext}(f, (x = 1.0, y = 2.0), (x = 42,), DefaultContext(), nothing) ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: @@ -39,6 +50,7 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + tracked_varnames::Union{Nothing,Array{<:VarName}} @doc """ Model{missings}(f, 
args::NamedTuple, defaults::NamedTuple) @@ -51,9 +63,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), + tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing, ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + f, args, defaults, context, tracked_varnames ) end end @@ -71,6 +84,7 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), + tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing, ) where {F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing @@ -78,15 +92,36 @@ model with different arguments. missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...)}( + f, args, defaults, context, tracked_varnames + )) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) 
- return Model(f, args, NamedTuple(kwargs), context) +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + tracked_varnames::Union{Nothing,Array{<:VarName}}=nothing; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, tracked_varnames) end function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, model.tracked_varnames) +end + +""" + set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}}) + +Return a new `Model` which only tracks a subset of variables during sampling. + +If `varnames` is `nothing`, then all variables will be tracked. Otherwise, only +the variables subsumed by `varnames` are tracked. For example, if `varnames = +[@varname(x)]`, then any variable `x`, `x[1]`, `x.a`, ... will be tracked. +""" +function set_tracked_varnames(model::Model, varnames::Union{Nothing,Array{<:VarName}}) + return Model(model.f, model.args, model.defaults, model.context, varnames) end """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4c1db34c2..42990f3e8 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -195,39 +195,12 @@ function values_as_in_model( model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo, - tracked_varnames=tracked_varnames(model), + tracked_varnames=model.tracked_varnames, context::AbstractContext=DefaultContext(), ) + @show tracked_varnames tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames) context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context) evaluate!!(model, varinfo, context) return context.values end - -""" - tracked_varnames(model::Model) - -Returns a set of `VarName`s that the model should track. - -By default, this returns `nothing`, which means that all `VarName`s should be -tracked. 
- -If you want to track only a subset of `VarName`s, you can override this method -in your model definition: - -```julia -@model function mymodel() - x ~ Normal() - y ~ Normal(x, 1) -end - -DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)] -``` - -Then, when you sample from `mymodel()`, only the value of `y` will be tracked -(and not `x`). - -Note that quantities on the left-hand side of `:=` are always tracked, and will -ignore the varnames specified in this method. -""" -tracked_varnames(::Model) = nothing diff --git a/test/contexts.jl b/test/contexts.jl index faa831cc1..e4f1b57c3 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -167,7 +167,7 @@ end ctx1 = PrefixContext{:a}(DefaultContext()) ctx2 = SamplingContext(ctx1) ctx3 = PrefixContext{:b}(ctx2) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, nothing, ctx3) vn_prefixed1 = prefix(ctx1, vn) vn_prefixed2 = prefix(ctx2, vn) vn_prefixed3 = prefix(ctx3, vn) diff --git a/test/values_as_in_model.jl b/test/values_as_in_model.jl index d21f980ee..6cb4731f8 100644 --- a/test/values_as_in_model.jl +++ b/test/values_as_in_model.jl @@ -82,9 +82,9 @@ @test !haskey(values, @varname(y)) @test haskey(values, @varname(z)) # := is always included - # Specify instead using `tracked_varnames` method - DynamicPPL.tracked_varnames(::Model{typeof(track_specific)}) = [@varname(y)] - values = values_as_in_model(model, true, vi) + # Specify instead using `set_tracked_varnames` method + model2 = DynamicPPL.set_tracked_varnames(model, [@varname(y)]) + values = values_as_in_model(model2, true, vi) @test !haskey(values, @varname(x[1])) @test !haskey(values, @varname(x[2])) @test haskey(values, @varname(y)) From 06d6e4e94d122fa8158d05982f0c2c46dcc8e8e8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Mar 2025 10:54:04 +0000 Subject: [PATCH 3/5] Add documentation --- docs/src/api.md | 7 +++++++ 1 file changed, 7 
insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 9c8249c97..a4895edff 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -201,6 +201,13 @@ Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are see values_as_in_model ``` +`values_as_in_model` also uses the `tracked_varnames` field on a [`Model`](@ref) to determine which variables are extracted. +To change the value of this field, you can use [`set_tracked_varnames`](@ref). + +```@docs +set_tracked_varnames +``` + ```@docs NamedDist ``` From d7ef01ded5405f3eea6b5eb1e7dfab765c2c38ee Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Mar 2025 11:03:45 +0000 Subject: [PATCH 4/5] Add HISTORY.md entry for set_tracked_varnames --- HISTORY.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index b36003965..b414b5650 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,30 @@ # DynamicPPL Changelog +## 0.36.0 + +**Breaking changes** + +### `set_tracked_varnames` + +There is now a new function `set_tracked_varnames(::Model, ::Union{Nothing,Array{<:VarName}})`, which allows you to specify the variables that are collected when `values_as_in_model` is run. +Internally in DynamicPPL this does not have much impact. +However, Turing.jl uses `values_as_in_model` to collect the variable names and values during sampling, and so this function will effectively allow you to control which variables are ultimately stored in a chain. + +Example usage: + +```julia +@model function f() + x ~ Normal() + y ~ Normal() + return x, y +end + +model = f() +model = set_tracked_varnames(model, [@varname(y)]) +``` + +If you then sample from `model`, only the value of `y` will be stored in the chain, and not `x`. Variables on the left-hand side of `:=` are not filtered by this setting: they are still collected whenever `values_as_in_model` is called with `include_colon_eq` set to `true`. 
+ ## 0.35.0 **Breaking changes** From 67b0e3f8b44bfbd5e465b4be3a7d874dbe028f5a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 18 Mar 2025 11:58:03 +0000 Subject: [PATCH 5/5] Remove a stray @show --- src/values_as_in_model.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 42990f3e8..615029f41 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -198,7 +198,6 @@ function values_as_in_model( tracked_varnames=model.tracked_varnames, context::AbstractContext=DefaultContext(), ) - @show tracked_varnames tracked_varnames = isnothing(tracked_varnames) ? nothing : collect(tracked_varnames) context = ValuesAsInModelContext(include_colon_eq, tracked_varnames, context) evaluate!!(model, varinfo, context)