Skip to content

Commit 8c5f782

Browse files
committed
Enable non-identity VarNames in Gibbs
Closes #2403
1 parent ce0e616 commit 8c5f782

File tree

1 file changed

+51
-21
lines changed

1 file changed

+51
-21
lines changed

src/mcmc/gibbs.jl

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,13 @@ for type stability of `tilde_assume`.
5959
# Fields
6060
$(FIELDS)
6161
"""
62-
struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <:
63-
DynamicPPL.AbstractContext
62+
struct GibbsContext{
63+
VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext
64+
} <: DynamicPPL.AbstractContext
65+
"""
66+
the VarNames being sampled
67+
"""
68+
target_varnames::VNs
6469
"""
6570
a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both
6671
those fixed and those being sampled. We use a `Ref` because this field may need to be
@@ -72,26 +77,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont
7277
"""
7378
context::Ctx
7479

75-
function GibbsContext{VNs}(global_varinfo, context) where {VNs}
76-
if !can_be_wrapped(context)
77-
error("GibbsContext can only wrap a leaf context, not a $(context).")
78-
end
79-
return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context)
80-
end
81-
8280
function GibbsContext(target_varnames, global_varinfo, context)
8381
if !can_be_wrapped(context)
8482
error("GibbsContext can only wrap a leaf context, not a $(context).")
8583
end
86-
if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames)
87-
msg =
88-
"All Gibbs target variables must have identity lenses. " *
89-
"For example, you can't have `@varname(x.a[1])` as a target variable, " *
90-
"only `@varname(x)`."
91-
error(msg)
92-
end
93-
vn_sym = tuple(unique((DynamicPPL.getsym(vn) for vn in target_varnames))...)
94-
return new{vn_sym,typeof(global_varinfo),typeof(context)}(global_varinfo, context)
84+
target_varnames = tuple(target_varnames...) # Allow vectors.
85+
return new{typeof(target_varnames),typeof(global_varinfo),typeof(context)}(
86+
target_varnames, global_varinfo, context
87+
)
9588
end
9689
end
9790

@@ -101,8 +94,10 @@ end
10194

10295
DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
10396
DynamicPPL.childcontext(context::GibbsContext) = context.context
104-
function DynamicPPL.setchildcontext(context::GibbsContext{VNs}, childcontext) where {VNs}
105-
return GibbsContext{VNs}(Ref(context.global_varinfo[]), childcontext)
97+
function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
98+
return GibbsContext(
99+
context.target_varnames, Ref(context.global_varinfo[]), childcontext
100+
)
106101
end
107102

108103
get_global_varinfo(context::GibbsContext) = context.global_varinfo[]
@@ -134,7 +129,9 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
134129
return map(Base.Fix1(get_conditioned_gibbs, context), vns)
135130
end
136131

137-
is_target_varname(::GibbsContext{VNs}, ::VarName{sym}) where {VNs,sym} = sym in VNs
132+
function is_target_varname(ctx::GibbsContext, vn::VarName)
133+
return any(Base.Fix2(subsumes, vn), ctx.target_varnames)
134+
end
138135

139136
function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
140137
num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns))
@@ -150,6 +147,37 @@ end
150147
# Tilde pipeline
151148
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
152149
child_context = DynamicPPL.childcontext(context)
150+
151+
# Note that `child_context` may contain `PrefixContext`s -- in which case
152+
# we need to make sure that vn is appropriately prefixed before we handle
153+
# the `GibbsContext` behaviour below. For example, consider the following:
154+
# @model inner() = x ~ Normal()
155+
# @model outer() = a ~ to_submodel(inner())
156+
# If we run this with `Gibbs(@varname(a.x) => MH())`, then when we are
157+
# executing the submodel, the `context` will contain the `@varname(a.x)`
158+
# variable; `child_context` will contain `PrefixContext(@varname(a))`; and
159+
# `vn` will just be `@varname(x)`. If we just simply run
160+
# `is_target_varname(context, vn)`, it will return false, and everything
161+
# will be messed up.
162+
# TODO(penelopeysm): This 'problem' could be solved if we made GibbsContext a
163+
# leaf context and wrapped the PrefixContext _above_ the GibbsContext, so
164+
# that the prefixing would be handled by tilde_assume(::PrefixContext, ...)
165+
# _before_ we hit this method.
166+
# In the current state of GibbsContext, doing this would require
167+
# special-casing the way PrefixContext is used to wrap the leaf context.
168+
# This is very inconvenient because PrefixContext's behaviour is defined in
169+
# DynamicPPL, and we would basically have to create a new method in Turing
170+
# and override it for GibbsContext. Indeed, a better way to do this would
171+
# be to make GibbsContext a leaf context. In this case, we would be able to
172+
# rely on the existing behaviour of DynamicPPL.make_evaluate_args_and_kwargs
173+
# to correctly wrap the PrefixContext around the GibbsContext. This is very
174+
# tricky to correctly do now, but once we remove the other leaf contexts
175+
# (i.e. PriorContext and LikelihoodContext), we should be able to do this.
176+
# This is already implemented in
177+
# https://github.com/TuringLang/DynamicPPL.jl/pull/885/ but not yet
178+
# released. Exciting!
179+
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
180+
153181
return if is_target_varname(context, vn)
154182
# Fall back to the default behavior.
155183
DynamicPPL.tilde_assume(child_context, right, vn, vi)
@@ -182,6 +210,8 @@ function DynamicPPL.tilde_assume(
182210
)
183211
# See comment in the above, rng-less version of this method for an explanation.
184212
child_context = DynamicPPL.childcontext(context)
213+
vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
214+
185215
return if is_target_varname(context, vn)
186216
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
187217
elseif has_conditioned_gibbs(context, vn)

0 commit comments

Comments
 (0)