Support DPPL 0.37 #2550

Status: Draft. Wants to merge 1 commit into base: main.
2 changes: 1 addition & 1 deletion Project.toml
@@ -62,7 +62,7 @@
 Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.36"
+DynamicPPL = "0.37"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3"
 Libtask = "0.8.8"
33 changes: 17 additions & 16 deletions ext/TuringOptimExt.jl
@@ -34,8 +34,8 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
+    f = Optimisation.OptimLogDensity(model, vi)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +57,8 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
+    f = Optimisation.OptimLogDensity(model, vi)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -74,8 +74,9 @@ function Optim.optimize(
 end

 function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
-    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
+    f = Optimisation.OptimLogDensity(model, vi)
+    return _optimize(f, args...; kwargs...)
 end

 """
@@ -104,8 +105,8 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

    vi = DynamicPPL.setaccs!!(
        VarInfo(model),
        (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
    )
+    f = Optimisation.OptimLogDensity(model, vi)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +128,8 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    f = Optimisation.OptimLogDensity(model, ctx)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

    vi = DynamicPPL.setaccs!!(
        VarInfo(model),
        (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
    )
+    f = Optimisation.OptimLogDensity(model, vi)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -144,9 +145,11 @@ function Optim.optimize(
 end

 function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
-    return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
+    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

    vi = DynamicPPL.setaccs!!(
        VarInfo(model),
        (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
    )
+    f = Optimisation.OptimLogDensity(model, vi)
+    return _optimize(f, args...; kwargs...)
 end

 """
     _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

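Editorial note: the MAP analogue of the sketch above, reusing the toy `model`. One assumption to flag: `LogPriorWithoutJacobianAccumulator` is used unqualified in this extension, so it is presumably defined in or imported into `Turing.Optimisation`; the qualified name below is a guess.

```julia
# MAP objective: log-prior (without the link-transform Jacobian correction)
# plus log-likelihood. Excluding the Jacobian makes the optimum the mode in
# the model's original parameterisation, not in unconstrained space.
vi = DynamicPPL.setaccs!!(
    DynamicPPL.VarInfo(model),
    (
        Optimisation.LogPriorWithoutJacobianAccumulator(),  # assumed location
        DynamicPPL.LogLikelihoodAccumulator(),
    ),
)
f = Optimisation.OptimLogDensity(model, vi)
```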
@@ -166,7 +169,7 @@ function _optimize(
     # whether initialisation is really necessary at all
     vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
     vi = DynamicPPL.link(vi, f.ldf.model)
-    f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
+    f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype)
     init_vals = DynamicPPL.getparams(f.ldf)

     # Optimize!
@@ -183,9 +186,7 @@
     # Get the optimum in unconstrained space. `getparams` does the invlinking.
     vi = f.ldf.varinfo
     vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
-    logdensity_optimum = Optimisation.OptimLogDensity(
-        f.ldf.model, vi_optimum, f.ldf.context
-    )
+    logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype)
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

    logdensity_optimum = Optimisation.OptimLogDensity(
        f.ldf.model, vi_optimum; adtype=f.ldf.adtype
    )
     vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
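Editorial note: the user-facing entry points these internal methods back are unchanged by this PR. A usage sketch (model invented for illustration):

```julia
using Turing, Optim

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

# Both calls route through the `_mle_optimize` / `_map_optimize` paths above.
mle_estimate = Optim.optimize(gdemo(1.5), MLE())
map_estimate = Optim.optimize(gdemo(1.5), MAP())
```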
11 changes: 3 additions & 8 deletions src/mcmc/Inference.jl
@@ -22,8 +22,6 @@ using DynamicPPL:
     SampleFromPrior,
     SampleFromUniform,
     DefaultContext,
-    PriorContext,
-    LikelihoodContext,
     set_flag!,
     unset_flag!
 using Distributions, Libtask, Bijectors
@@ -75,7 +73,6 @@ export InferenceAlgorithm,
     RepeatSampler,
     Prior,
     assume,
-    observe,
     predict,
     externalsampler

@@ -182,12 +179,10 @@ function AbstractMCMC.step(
     state=nothing;
     kwargs...,
 )
+    vi = VarInfo()
+    vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),))
     vi = last(
-        DynamicPPL.evaluate!!(
-            model,
-            VarInfo(),
-            SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
-        ),
+        DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())),
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

        DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior()))
     )
     return vi, nothing
 end
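Editorial note: a standalone sketch of what the new `Prior` step does, using the 0.37 accumulator API shown in this hunk. `getlogprior` is assumed here as the reader-side accessor; only `setaccs!!`, `evaluate!!`, and `SamplingContext` appear in the diff itself.

```julia
import DynamicPPL
using Random

# Given any `model::DynamicPPL.Model`: track only the log-prior, then draw
# fresh values from the prior by evaluating under SampleFromPrior.
vi = DynamicPPL.setaccs!!(DynamicPPL.VarInfo(), (DynamicPPL.LogPriorAccumulator(),))
_, vi = DynamicPPL.evaluate!!(
    model,
    vi,
    DynamicPPL.SamplingContext(Random.default_rng(), DynamicPPL.SampleFromPrior()),
)
DynamicPPL.getlogprior(vi)  # log-prior of the freshly drawn values (assumed accessor)
```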
20 changes: 6 additions & 14 deletions src/mcmc/ess.jl
@@ -49,7 +49,7 @@ function AbstractMCMC.step(
         rng,
         EllipticalSliceSampling.ESSModel(
             ESSPrior(model, spl, vi),
-            DynamicPPL.LogDensityFunction(
+            DynamicPPL.LogDensityFunction{:LogLikelihood}(
                 model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
             ),
         ),
@@ -59,7 +59,7 @@

     # update sample and log-likelihood
     vi = DynamicPPL.unflatten(vi, sample)
-    vi = setlogp!!(vi, state.loglikelihood)
+    vi = setloglikelihood!!(vi, state.loglikelihood)

     return Transition(model, vi), vi
 end
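Editorial note: this one-line change is the crux of the 0.37 accumulator split. A before/after sketch, names as used in the diff:

```julia
# DynamicPPL 0.36: a single shared log-density slot on the VarInfo.
# vi = setlogp!!(vi, state.loglikelihood)

# DynamicPPL 0.37: log-prior and log-likelihood are accumulated separately,
# so the ESS state's log-likelihood must be written to its own slot.
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
```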
@@ -108,20 +108,12 @@ end
 # Mean of prior distribution
 Distributions.mean(p::ESSPrior) = p.μ

-# Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
-    DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
-
-(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
-
 function DynamicPPL.tilde_assume(
-    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
+    rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
 )
-    return DynamicPPL.tilde_assume(
-        rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
-    )
+    return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
 end

-function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
-    return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
+function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi)
Reviewer comment ([JuliaFormatter v1.0.62] reported by reviewdog 🐶), suggested change:

    function DynamicPPL.tilde_observe!!(
        ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi
    )
+    return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
 end
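Editorial note: the deleted `ESSLogLikelihood` alias is superseded by the `{:LogLikelihood}` type parameter used in the `step` hunk above. A hedged sketch of how such a likelihood-only density is queried, assuming the parameterised constructor accepts the same arguments as the plain one (as the diff suggests):

```julia
import DynamicPPL, LogDensityProblems

# Likelihood-only log density, as handed to EllipticalSliceSampling above;
# `model`, `vi`, and `spl` as in the surrounding sampler code.
ℓ = DynamicPPL.LogDensityFunction{:LogLikelihood}(
    model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
)
θ = DynamicPPL.getparams(ℓ)          # current parameter vector, per the diff's usage
LogDensityProblems.logdensity(ℓ, θ)  # log p(data | θ), not the joint
```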
12 changes: 7 additions & 5 deletions src/mcmc/gibbs.jl
@@ -32,7 +32,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
 #
 # Purpose: avoid triggering resampling of variables we're conditioning on.
 # - Using standard `DynamicPPL.condition` results in conditioned variables being treated
-#   as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
+#   as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
 # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
 #   undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
 #   rather than only for the "true" observations.
@@ -177,24 +177,26 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
         DynamicPPL.tilde_assume(child_context, right, vn, vi)
     elseif has_conditioned_gibbs(context, vn)
         # Short-circuit the tilde assume if `vn` is present in `context`.
-        value, lp, _ = DynamicPPL.tilde_assume(
+        # TODO(mhauru) Fix accumulation here. In this branch anything that gets
+        # accumulated just gets discarded with `_`.
+        value, _ = DynamicPPL.tilde_assume(
             child_context, right, vn, get_global_varinfo(context)
         )
-        value, lp, vi
+        value, vi
     else
         # If the varname has not been conditioned on, nor is it a target variable, its
         # presumably a new variable that should be sampled from its prior. We need to add
         # this new variable to the global `varinfo` of the context, but not to the local one
         # being used by the current sampler.
-        value, lp, new_global_vi = DynamicPPL.tilde_assume(
+        value, new_global_vi = DynamicPPL.tilde_assume(
             child_context,
             DynamicPPL.SampleFromPrior(),
             right,
             vn,
             get_global_varinfo(context),
         )
         set_global_varinfo!(context, new_global_vi)
-        value, lp, vi
+        value, vi
     end
 end

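Editorial note: the mechanical change in this hunk is that `tilde_assume` in DynamicPPL 0.37 returns a `(value, varinfo)` pair rather than `(value, logp, varinfo)`; log-probabilities are accumulated inside the varinfo. Sketch:

```julia
# DynamicPPL 0.36: the log-prior contribution came back explicitly.
# value, lp, vi = DynamicPPL.tilde_assume(ctx, right, vn, vi)

# DynamicPPL 0.37: the contribution is accumulated into `vi` by the call
# itself. Callers that discard the returned varinfo (as the short-circuit
# branch above does with `_`) therefore drop that accumulation, which is
# what the TODO in this hunk flags.
value, vi = DynamicPPL.tilde_assume(ctx, right, vn, vi)
```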
4 changes: 0 additions & 4 deletions src/mcmc/hmc.jl
@@ -501,10 +501,6 @@ function DynamicPPL.assume(
     return DynamicPPL.assume(dist, vn, vi)
 end

-function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
-    return DynamicPPL.observe(d, value, vi)
-end
-
 ####
 #### Default HMC stepsize and mass matrix adaptor
 ####
4 changes: 0 additions & 4 deletions src/mcmc/is.jl
@@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
     end
     return r, 0, vi
 end
-
-function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
-    return logpdf(dist, value), vi
-end
4 changes: 0 additions & 4 deletions src/mcmc/mh.jl
@@ -390,7 +390,3 @@ function DynamicPPL.assume(
     retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
     return retval
 end
-
-function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
-    return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
-end
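Editorial note: the `observe` deletions in hmc.jl, is.jl, and mh.jl all follow from the same 0.37 change: observation log-probabilities now flow through `tilde_observe!!` into a `LogLikelihoodAccumulator`, so pass-through sampler overloads are redundant. A sketch of the accumulator doing that job; the toy model is invented, and `getloglikelihood` is assumed as the accessor counterpart of the `setloglikelihood!!` used in ess.jl above:

```julia
import DynamicPPL
using DynamicPPL: @model
using Distributions

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

model = demo(0.5)
vi = DynamicPPL.setaccs!!(
    DynamicPPL.VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)
)
_, vi = DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext())
# The observation `x = 0.5` contributed here via `tilde_observe!!`; no
# sampler-specific `observe` method was involved.
DynamicPPL.getloglikelihood(vi)
```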
22 changes: 12 additions & 10 deletions src/mcmc/particle_mcmc.jl
@@ -379,10 +379,11 @@ function DynamicPPL.assume(
     return r, lp, vi
 end

-function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
-    # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
-    return logpdf(dist, value), trace_local_varinfo_maybe(vi)
-end
+# TODO(mhauru) Fix this.
+# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
+#     # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
+#     return logpdf(dist, value), trace_local_varinfo_maybe(vi)
+# end

 function DynamicPPL.acclogp!!(
     context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
@@ -391,12 +392,13 @@ function DynamicPPL.acclogp!!(
     return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
 end

-function DynamicPPL.acclogp_observe!!(
-    context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
-)
-    Libtask.produce(logp)
-    return trace_local_varinfo_maybe(varinfo)
-end
+# TODO(mhauru) Fix this.
+# function DynamicPPL.acclogp_observe!!(
+#     context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
+# )
+#     Libtask.produce(logp)
+#     return trace_local_varinfo_maybe(varinfo)
+# end

 # Convenient constructor
 function AdvancedPS.Trace(
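Editorial note, for context on the TODOs above: under 0.36 each observation's log-weight reached the particle sweep via `Libtask.produce` inside `acclogp_observe!!`, suspending the particle's task until the resampler consumed it. A minimal sketch of that produce/consume handshake, assuming Libtask 0.8's `TapedTask`/`produce`/`consume` API (per the compat bound in Project.toml):

```julia
using Libtask

# Each `produce` suspends the task and hands one observation's log-weight
# to whoever `consume`s it, mirroring how SMC/PG interleave model execution
# with resampling decisions.
weights() = begin
    Libtask.produce(-0.5)  # log-weight at the first observation
    Libtask.produce(-1.2)  # log-weight at the second observation
end

t = Libtask.TapedTask(weights)
Libtask.consume(t)  # -0.5; particle paused at the first observation
Libtask.consume(t)  # -1.2
```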