@@ -179,33 +179,33 @@ get_varinfo(s::MHState) = s.varinfo
179179# ####################
180180
181181"""
182- set_namedtuple!(vi::VarInfo, nt::NamedTuple)
182+ OldLogDensityFunction
183183
184- Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
184+ This is a clone of pre-0.39 DynamicPPL.LogDensityFunction. It is needed for MH because MH
185+ doesn't actually obey the LogDensityProblems.jl interface: it evaluates
186+ 'LogDensityFunctions' with a NamedTuple(!!)
187+
188+ This means that we can't _really_ use DynamicPPL's LogDensityFunction, since that only
189+ promises to obey the interface of being called with a vector.
190+
191+ In particular, because `set_namedtuple!` acts on a VarInfo, we need to store the VarInfo
192+ inside this struct (which DynamicPPL's LogDensityFunction no longer does).
193+
194+ This SHOULD really be refactored to remove this requirement.
185195"""
186- function set_namedtuple! (vi:: DynamicPPL.VarInfoOrThreadSafeVarInfo , nt:: NamedTuple )
187- for (n, vals) in pairs (nt)
188- vns = vi. metadata[n]. vns
189- if vals isa AbstractVector
190- vals = unvectorize (vals)
191- end
192- if length (vns) == 1
193- # Only one variable, assign the values to it
194- DynamicPPL. setindex! (vi, vals, vns[1 ])
195- else
196- # Spread the values across the variables
197- length (vns) == length (vals) || error (" Unequal number of variables and values" )
198- for (vn, val) in zip (vns, vals)
199- DynamicPPL. setindex! (vi, val, vn)
200- end
201- end
202- end
196+ struct OldLogDensityFunction{M<: DynamicPPL.Model ,V<: DynamicPPL.AbstractVarInfo }
197+ model:: M
198+ varinfo:: V
199+ end
200+ function (f:: OldLogDensityFunction )(x:: AbstractVector )
201+ vi = DynamicPPL. unflatten (f. varinfo, x)
202+ _, vi = DynamicPPL. evaluate!! (f. model, vi)
203+ return DynamicPPL. getlogjoint_internal (vi)
203204end
204-
205205# NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems
206206# interface in that it gets evaluated with a NamedTuple. Hence we need this
207207# method just to deal with MH.
208- function LogDensityProblems . logdensity (f:: LogDensityFunction , x:: NamedTuple )
208+ function (f:: OldLogDensityFunction )( x:: NamedTuple )
209209 vi = deepcopy (f. varinfo)
210210 # Note that the NamedTuple `x` does NOT conform to the structure required for
211211 # `InitFromParams`. In particular, for models that look like this:
@@ -223,8 +223,31 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple)
223223 set_namedtuple! (vi, x)
224224 # Update log probability.
225225 _, vi_new = DynamicPPL. evaluate!! (f. model, vi)
226- lj = f. getlogdensity (vi_new)
227- return lj
226+ return DynamicPPL. getlogjoint_internal (vi_new)
227+ end
228+
229+ """
230+ set_namedtuple!(vi::VarInfo, nt::NamedTuple)
231+
232+ Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
233+ """
234+ function set_namedtuple! (vi:: DynamicPPL.VarInfoOrThreadSafeVarInfo , nt:: NamedTuple )
235+ for (n, vals) in pairs (nt)
236+ vns = vi. metadata[n]. vns
237+ if vals isa AbstractVector
238+ vals = unvectorize (vals)
239+ end
240+ if length (vns) == 1
241+ # Only one variable, assign the values to it
242+ DynamicPPL. setindex! (vi, vals, vns[1 ])
243+ else
244+ # Spread the values across the variables
245+ length (vns) == length (vals) || error (" Unequal number of variables and values" )
246+ for (vn, val) in zip (vns, vals)
247+ DynamicPPL. setindex! (vi, val, vn)
248+ end
249+ end
250+ end
228251end
229252
230253# unpack a vector if possible
@@ -335,12 +358,7 @@ function propose!!(rng::AbstractRNG, prev_state::MHState, model::Model, spl::MH,
335358
336359 # Make a new transition.
337360 model = DynamicPPL. setleafcontext (model, MHContext (rng))
338- densitymodel = AMH. DensityModel (
339- Base. Fix1 (
340- LogDensityProblems. logdensity,
341- DynamicPPL. LogDensityFunction (model, DynamicPPL. getlogjoint_internal, vi),
342- ),
343- )
361+ densitymodel = AMH. DensityModel (OldLogDensityFunction (model, vi))
344362 trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
345363 # trans.params isa NamedTuple
346364 set_namedtuple! (vi, trans. params)
@@ -370,12 +388,7 @@ function propose!!(
370388
371389 # Make a new transition.
372390 model = DynamicPPL. setleafcontext (model, MHContext (rng))
373- densitymodel = AMH. DensityModel (
374- Base. Fix1 (
375- LogDensityProblems. logdensity,
376- DynamicPPL. LogDensityFunction (model, DynamicPPL. getlogjoint_internal, vi),
377- ),
378- )
391+ densitymodel = AMH. DensityModel (OldLogDensityFunction (model, vi))
379392 trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
380393 # trans.params isa AbstractVector
381394 vi = DynamicPPL. unflatten (vi, trans. params)
0 commit comments