@@ -59,8 +59,13 @@ for type stability of `tilde_assume`.
5959# Fields
6060$(FIELDS)
6161"""
62- struct GibbsContext{VNs,GVI<: Ref{<:AbstractVarInfo} ,Ctx<: DynamicPPL.AbstractContext } < :
63- DynamicPPL. AbstractContext
62+ struct GibbsContext{
63+ VNs<: Tuple{Vararg{VarName}} ,GVI<: Ref{<:AbstractVarInfo} ,Ctx<: DynamicPPL.AbstractContext
64+ } <: DynamicPPL.AbstractContext
65+ """
66+ the VarNames being sampled
67+ """
68+ target_varnames:: VNs
6469 """
6570 a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both
6671 those fixed and those being sampled. We use a `Ref` because this field may need to be
@@ -72,26 +77,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont
7277 """
7378 context:: Ctx
7479
75- function GibbsContext {VNs} (global_varinfo, context) where {VNs}
76- if ! can_be_wrapped (context)
77- error (" GibbsContext can only wrap a leaf context, not a $(context) ." )
78- end
79- return new {VNs,typeof(global_varinfo),typeof(context)} (global_varinfo, context)
80- end
81-
8280 function GibbsContext (target_varnames, global_varinfo, context)
8381 if ! can_be_wrapped (context)
8482 error (" GibbsContext can only wrap a leaf context, not a $(context) ." )
8583 end
86- if any (vn -> DynamicPPL. getoptic (vn) != identity, target_varnames)
87- msg =
88- " All Gibbs target variables must have identity lenses. " *
89- " For example, you can't have `@varname(x.a[1])` as a target variable, " *
90- " only `@varname(x)`."
91- error (msg)
92- end
93- vn_sym = tuple (unique ((DynamicPPL. getsym (vn) for vn in target_varnames))... )
94- return new {vn_sym,typeof(global_varinfo),typeof(context)} (global_varinfo, context)
84+ target_varnames = tuple (target_varnames... ) # Allow vectors.
85+ return new {typeof(target_varnames),typeof(global_varinfo),typeof(context)} (
86+ target_varnames, global_varinfo, context
87+ )
9588 end
9689end
9790
10194
10295DynamicPPL. NodeTrait (:: GibbsContext ) = DynamicPPL. IsParent ()
10396DynamicPPL. childcontext (context:: GibbsContext ) = context. context
104- function DynamicPPL. setchildcontext (context:: GibbsContext{VNs} , childcontext) where {VNs}
105- return GibbsContext {VNs} (Ref (context. global_varinfo[]), childcontext)
97+ function DynamicPPL. setchildcontext (context:: GibbsContext , childcontext)
98+ return GibbsContext (
99+ context. target_varnames, Ref (context. global_varinfo[]), childcontext
100+ )
106101end
107102
108103get_global_varinfo (context:: GibbsContext ) = context. global_varinfo[]
@@ -134,7 +129,9 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
134129 return map (Base. Fix1 (get_conditioned_gibbs, context), vns)
135130end
136131
137- is_target_varname (:: GibbsContext{VNs} , :: VarName{sym} ) where {VNs,sym} = sym in VNs
132+ function is_target_varname (ctx:: GibbsContext , vn:: VarName )
133+ return any (Base. Fix2 (subsumes, vn), ctx. target_varnames)
134+ end
138135
139136function is_target_varname (context:: GibbsContext , vns:: AbstractArray{<:VarName} )
140137 num_target = count (Iterators. map (Base. Fix1 (is_target_varname, context), vns))
150147# Tilde pipeline
151148function DynamicPPL. tilde_assume (context:: GibbsContext , right, vn, vi)
152149 child_context = DynamicPPL. childcontext (context)
150+
151+ # Note that `child_context` may contain `PrefixContext`s -- in which case
152+ # we need to make sure that vn is appropriately prefixed before we handle
153+ # the `GibbsContext` behaviour below. For example, consider the following:
154+ # @model inner() = x ~ Normal()
155+ # @model outer() = a ~ to_submodel(inner())
156+ # If we run this with `Gibbs(@varname(a.x) => MH())`, then when we are
157+ # executing the submodel, the `context` will contain the `@varname(a.x)`
158+ # variable; `child_context` will contain `PrefixContext(@varname(a))`; and
159+ # `vn` will just be `@varname(x)`. If we just simply run
160+ # `is_target_varname(context, vn)`, it will return false, and everything
161+ # will be messed up.
162+ # TODO (penelopeysm): This 'problem' could be solved if we made GibbsContext a
163+ # leaf context and wrapped the PrefixContext _above_ the GibbsContext, so
164+ # that the prefixing would be handled by tilde_assume(::PrefixContext, ...)
165+ # _before_ we hit this method.
166+ # In the current state of GibbsContext, doing this would require
167+ # special-casing the way PrefixContext is used to wrap the leaf context.
168+ # This is very inconvenient because PrefixContext's behaviour is defined in
169+ # DynamicPPL, and we would basically have to create a new method in Turing
170+ # and override it for GibbsContext. Indeed, a better way to do this would
171+ # be to make GibbsContext a leaf context. In this case, we would be able to
172+ # rely on the existing behaviour of DynamicPPL.make_evaluate_args_and_kwargs
173+ # to correctly wrap the PrefixContext around the GibbsContext. This is very
174+ # tricky to correctly do now, but once we remove the other leaf contexts
175+ # (i.e. PriorContext and LikelihoodContext), we should be able to do this.
176+ # This is already implemented in
177+ # https://github.com/TuringLang/DynamicPPL.jl/pull/885/ but not yet
178+ # released. Exciting!
179+ vn, child_context = DynamicPPL. prefix_and_strip_contexts (child_context, vn)
180+
153181 return if is_target_varname (context, vn)
154182 # Fall back to the default behavior.
155183 DynamicPPL. tilde_assume (child_context, right, vn, vi)
@@ -182,6 +210,8 @@ function DynamicPPL.tilde_assume(
182210)
183211 # See comment in the above, rng-less version of this method for an explanation.
184212 child_context = DynamicPPL. childcontext (context)
213+ vn, child_context = DynamicPPL. prefix_and_strip_contexts (child_context, vn)
214+
185215 return if is_target_varname (context, vn)
186216 DynamicPPL. tilde_assume (rng, child_context, sampler, right, vn, vi)
187217 elseif has_conditioned_gibbs (context, vn)
0 commit comments