Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.38.4

Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.

## 0.38.3

Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.38.3"
version = "0.38.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2"
Mooncake = "0.4"
PrettyTables = "3"
ReverseDiff = "1.15.3"
StableRNGs = "1"
StableRNGs = "1"
2 changes: 2 additions & 0 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ chosen_combinations = [
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
Expand Down
4 changes: 4 additions & 0 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
retvals = model(rng)
vns = [VarName{k}() for k in keys(retvals)]
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
elseif varinfo_choice == :typed_vector
DynamicPPL.typed_vector_varinfo(rng, model)
elseif varinfo_choice == :untyped_vector
DynamicPPL.untyped_vector_varinfo(rng, model)
else
error("Unknown varinfo choice: $varinfo_choice")
end
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ DynamicPPL.reset!
DynamicPPL.update!
DynamicPPL.insert!
DynamicPPL.loosen_types!!
DynamicPPL.tighten_types
DynamicPPL.tighten_types!!
```

```@docs
Expand Down
4 changes: 3 additions & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ function tilde_assume!!(
end
# Neither of these set the `trans` flag so we have to do it manually if
# necessary.
insert_transformed_value && set_transformed!!(vi, true, vn)
if insert_transformed_value
vi = set_transformed!!(vi, true, vn)
end
# `accumulate_assume!!` wants untransformed values as the second argument.
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
# We always return the untransformed value here, as that will determine
Expand Down
4 changes: 2 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true)
show_varname(io::IO, varname::VarName) = print(io, varname)
function show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
# Attempt to make the type concrete in case the symbol is shared.
return _show_varname(io, map(identity, varname))
return _show_varname(io, [vn for vn in varname])
end
function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
# Print the first and last element of the array.
Expand Down Expand Up @@ -407,7 +407,7 @@ julia> @model function demo_incorrect()
end
demo_incorrect (generic function with 2 methods)

julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
# alert us to the issue of `x` being sampled twice.
model = demo_incorrect(); varinfo = VarInfo(model);

Expand Down
6 changes: 3 additions & 3 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ box:
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
any effects of linking
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
by linking, since transforms are only applied to random variables)
by linking, since transforms are only applied to random variables)

!!! note
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
Expand Down Expand Up @@ -146,7 +146,7 @@ struct LogDensityFunction{
is_supported(adtype) ||
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
x = [val for val in varinfo[:]]
if use_closure(adtype)
prep = DI.prepare_gradient(
LogDensityAt(model, getlogdensity, varinfo), adtype, x
Expand Down Expand Up @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient(
) where {M,F,V,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
x = [val for val in x] # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
Expand Down
1 change: 1 addition & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
)
end
return vi
end

is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
Expand Down
4 changes: 2 additions & 2 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups:
1. _How to specify the results to compare against._

Once logp and its gradient has been calculated with the specified `adtype`,
it can optionally be tested for correctness. The exact way this is tested
it can optionally be tested for correctness. The exact way this is tested
is specified in the `test` parameter.

There are several options for this:
Expand Down Expand Up @@ -260,7 +260,7 @@ function run_ad(
if isnothing(params)
params = varinfo[:]
end
params = map(identity, params) # Concretise
params = [p for p in params] # Concretise

# Calculate log-density and gradient with the backend of interest
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
Expand Down
72 changes: 50 additions & 22 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ function untyped_vector_varinfo(
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy))
return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy))
end
function untyped_vector_varinfo(
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
Expand Down Expand Up @@ -789,18 +789,24 @@ function setval!(md::Metadata, val, vn::VarName)
return md.vals[getrange(md, vn)] = tovec(val)
end

function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName)
md = set_transformed!!(getmetadata(vi, vn), val, vn)
return Accessors.@set vi.metadata[getsym(vn)] = md
end

function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName)
set_transformed!!(getmetadata(vi, vn), val, vn)
return vi
md = set_transformed!!(getmetadata(vi, vn), val, vn)
return VarInfo(md, vi.accs)
end

function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName)
metadata.is_transformed[getidx(metadata, vn)] = val
return metadata
end

function set_transformed!!(vi::VarInfo, val::Bool)
for vn in keys(vi)
set_transformed!!(vi, val, vn)
vi = set_transformed!!(vi, val, vn)
end

return vi
Expand Down Expand Up @@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns)
end

@generated function _link!!(
::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names}
) where {metadata_names,vns_names}
expr = Expr(:block)
for f in metadata_names
Expand All @@ -988,7 +994,7 @@ end
expr.args,
quote
f_vns = vi.metadata.$f.vns
f_vns = filter_subsumed(vns.$f, f_vns)
f_vns = filter_subsumed(varnames.$f, f_vns)
if !isempty(f_vns)
if !is_transformed(vi, f_vns[1])
# Iterate over all `f_vns` and transform
Expand Down Expand Up @@ -1652,30 +1658,47 @@ end
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
the `VarInfo` `vi`, mutating if it makes sense.
"""
function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
elseif vi isa NTVarInfo
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
end
function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution)
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
md = push!!(getmetadata(vi, vn), vn, val, dist)
return VarInfo(md, vi.accs)
end

function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution)
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
sym = getsym(vn)
if vi isa NTVarInfo && ~haskey(vi.metadata, sym)
meta = if ~haskey(vi.metadata, sym)
# The NamedTuple doesn't have an entry for this variable, let's add one.
val = tovec(r)
md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
vi = Accessors.@set vi.metadata[sym] = md
_new_submetadata(vi, vn, val, dist)
else
meta = getmetadata(vi, vn)
push!(meta, vn, r, dist)
push!!(getmetadata(vi, vn), vn, val, dist)
end

vi = Accessors.@set vi.metadata[sym] = meta
return vi
end

function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
push!(getmetadata(vi, vn), vn, val, args...)
return vi
"""
_new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas}

Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing
SubMetas.
"""
@generated function _new_submetadata(
vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist
) where {Names,SubMetas}
has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters)
return if has_vnv
:(return _new_vnv_submetadata(vn, r, dist))
else
:(return _new_metadata_submetadata(vn, r, dist))
end
end

_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r])

function _new_metadata_submetadata(vn, r, dist)
val = tovec(r)
return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
end

function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...)
Expand All @@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist)
return meta
end

function BangBang.push!!(meta::Metadata, vn, r, dist)
push!(meta, vn, r, dist)
return meta
end

function Base.delete!(vi::VarInfo, vn::VarName)
delete!(getmetadata(vi, vn), vn)
return vi
Expand Down
Loading