@@ -59,8 +59,13 @@ for type stability of `tilde_assume`.
59
59
# Fields
60
60
$(FIELDS)
61
61
"""
62
- struct GibbsContext{VNs,GVI<: Ref{<:AbstractVarInfo} ,Ctx<: DynamicPPL.AbstractContext } < :
63
- DynamicPPL. AbstractContext
62
+ struct GibbsContext{
63
+ VNs<: Tuple{Vararg{VarName}} ,GVI<: Ref{<:AbstractVarInfo} ,Ctx<: DynamicPPL.AbstractContext
64
+ } <: DynamicPPL.AbstractContext
65
+ """
66
+ the VarNames being sampled
67
+ """
68
+ target_varnames:: VNs
64
69
"""
65
70
a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both
66
71
those fixed and those being sampled. We use a `Ref` because this field may need to be
@@ -72,26 +77,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont
72
77
"""
73
78
context:: Ctx
74
79
75
- function GibbsContext {VNs} (global_varinfo, context) where {VNs}
76
- if ! can_be_wrapped (context)
77
- error (" GibbsContext can only wrap a leaf context, not a $(context) ." )
78
- end
79
- return new {VNs,typeof(global_varinfo),typeof(context)} (global_varinfo, context)
80
- end
81
-
82
80
function GibbsContext (target_varnames, global_varinfo, context)
83
81
if ! can_be_wrapped (context)
84
82
error (" GibbsContext can only wrap a leaf context, not a $(context) ." )
85
83
end
86
- if any (vn -> DynamicPPL. getoptic (vn) != identity, target_varnames)
87
- msg =
88
- " All Gibbs target variables must have identity lenses. " *
89
- " For example, you can't have `@varname(x.a[1])` as a target variable, " *
90
- " only `@varname(x)`."
91
- error (msg)
92
- end
93
- vn_sym = tuple (unique ((DynamicPPL. getsym (vn) for vn in target_varnames))... )
94
- return new {vn_sym,typeof(global_varinfo),typeof(context)} (global_varinfo, context)
84
+ target_varnames = tuple (target_varnames... ) # Allow vectors.
85
+ return new {typeof(target_varnames),typeof(global_varinfo),typeof(context)} (
86
+ target_varnames, global_varinfo, context
87
+ )
95
88
end
96
89
end
97
90
101
94
102
95
DynamicPPL. NodeTrait (:: GibbsContext ) = DynamicPPL. IsParent ()
103
96
DynamicPPL. childcontext (context:: GibbsContext ) = context. context
104
- function DynamicPPL. setchildcontext (context:: GibbsContext{VNs} , childcontext) where {VNs}
105
- return GibbsContext {VNs} (Ref (context. global_varinfo[]), childcontext)
97
+ function DynamicPPL. setchildcontext (context:: GibbsContext , childcontext)
98
+ return GibbsContext (
99
+ context. target_varnames, Ref (context. global_varinfo[]), childcontext
100
+ )
106
101
end
107
102
108
103
get_global_varinfo (context:: GibbsContext ) = context. global_varinfo[]
@@ -134,7 +129,9 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
134
129
return map (Base. Fix1 (get_conditioned_gibbs, context), vns)
135
130
end
136
131
137
- is_target_varname (:: GibbsContext{VNs} , :: VarName{sym} ) where {VNs,sym} = sym in VNs
132
+ function is_target_varname (ctx:: GibbsContext , vn:: VarName )
133
+ return any (Base. Fix2 (subsumes, vn), ctx. target_varnames)
134
+ end
138
135
139
136
function is_target_varname (context:: GibbsContext , vns:: AbstractArray{<:VarName} )
140
137
num_target = count (Iterators. map (Base. Fix1 (is_target_varname, context), vns))
150
147
# Tilde pipeline
151
148
function DynamicPPL. tilde_assume (context:: GibbsContext , right, vn, vi)
152
149
child_context = DynamicPPL. childcontext (context)
150
+
151
+ # Note that `child_context` may contain `PrefixContext`s -- in which case
152
+ # we need to make sure that vn is appropriately prefixed before we handle
153
+ # the `GibbsContext` behaviour below. For example, consider the following:
154
+ # @model inner() = x ~ Normal()
155
+ # @model outer() = a ~ to_submodel(inner())
156
+ # If we run this with `Gibbs(@varname(a.x) => MH())`, then when we are
157
+ # executing the submodel, the `context` will contain the `@varname(a.x)`
158
+ # variable; `child_context` will contain `PrefixContext(@varname(a))`; and
159
+ # `vn` will just be `@varname(x)`. If we just simply run
160
+ # `is_target_varname(context, vn)`, it will return false, and everything
161
+ # will be messed up.
162
+ # TODO (penelopeysm): This 'problem' could be solved if we made GibbsContext a
163
+ # leaf context and wrapped the PrefixContext _above_ the GibbsContext, so
164
+ # that the prefixing would be handled by tilde_assume(::PrefixContext, ...)
165
+ # _before_ we hit this method.
166
+ # In the current state of GibbsContext, doing this would require
167
+ # special-casing the way PrefixContext is used to wrap the leaf context.
168
+ # This is very inconvenient because PrefixContext's behaviour is defined in
169
+ # DynamicPPL, and we would basically have to create a new method in Turing
170
+ # and override it for GibbsContext. Indeed, a better way to do this would
171
+ # be to make GibbsContext a leaf context. In this case, we would be able to
172
+ # rely on the existing behaviour of DynamicPPL.make_evaluate_args_and_kwargs
173
+ # to correctly wrap the PrefixContext around the GibbsContext. This is very
174
+ # tricky to correctly do now, but once we remove the other leaf contexts
175
+ # (i.e. PriorContext and LikelihoodContext), we should be able to do this.
176
+ # This is already implemented in
177
+ # https://github.com/TuringLang/DynamicPPL.jl/pull/885/ but not yet
178
+ # released. Exciting!
179
+ vn, child_context = DynamicPPL. prefix_and_strip_contexts (child_context, vn)
180
+
153
181
return if is_target_varname (context, vn)
154
182
# Fall back to the default behavior.
155
183
DynamicPPL. tilde_assume (child_context, right, vn, vi)
@@ -182,6 +210,8 @@ function DynamicPPL.tilde_assume(
182
210
)
183
211
# See comment in the above, rng-less version of this method for an explanation.
184
212
child_context = DynamicPPL. childcontext (context)
213
+ vn, child_context = DynamicPPL. prefix_and_strip_contexts (child_context, vn)
214
+
185
215
return if is_target_varname (context, vn)
186
216
DynamicPPL. tilde_assume (rng, child_context, sampler, right, vn, vi)
187
217
elseif has_conditioned_gibbs (context, vn)
0 commit comments