subset and merge for VarInfo (clean version) #544

Merged · 28 commits · Oct 19, 2023

Commits
028a81a
added `subset` which can extract a subset of the varinfo
torfjelde Oct 8, 2023
caa6e25
added testing of `subset` for `VarInfo`
torfjelde Oct 8, 2023
cac7fa8
formatting
torfjelde Oct 8, 2023
5e41c4f
added implementation of `merge` for `VarInfo` and tests for it
torfjelde Oct 8, 2023
d5a2631
more tests
torfjelde Oct 8, 2023
0ade696
formatting
torfjelde Oct 8, 2023
db21844
improved merge_metadata for NamedTuple inputs
torfjelde Oct 9, 2023
1dbca4c
added proper handling of the `vals` in `subset`
torfjelde Oct 9, 2023
b67288f
added docs for `subset` and `merge`
torfjelde Oct 9, 2023
e43029e
added `subset` and `merge` to documentation
torfjelde Oct 9, 2023
cd4033d
formatting
torfjelde Oct 9, 2023
8f47dfe
made merge and subset part of the AbstractVarInfo interface
torfjelde Oct 13, 2023
aba9008
added implementations `subset` and `merge` for `SimpleVarInfo`
torfjelde Oct 13, 2023
3b621ae
follow standard merge semantics where the right one takes precedence
torfjelde Oct 13, 2023
2c2c90b
added proper testing of merge and subset for SimpleVarInfo too
torfjelde Oct 13, 2023
5c1ece3
forgotten inclusion in previous commit
torfjelde Oct 13, 2023
cfff96c
Update src/simple_varinfo.jl
torfjelde Oct 13, 2023
ed5d948
remove two-argument impl of merge
torfjelde Oct 13, 2023
00c36cf
formatting
torfjelde Oct 13, 2023
cf02816
forgot to add more formatting
torfjelde Oct 13, 2023
d02cb61
Merge branch 'master' into torfjelde/subset-and-merge
torfjelde Oct 13, 2023
7f01ada
removed 2-arg version of merge for abstract varinfo in favour of 3-ar…
torfjelde Oct 13, 2023
14105e0
allow inclusion of threadsafe varinfo in setup_varinfos
torfjelde Oct 13, 2023
c164d32
more tests for thread safe varinfo
torfjelde Oct 13, 2023
743162a
bugfixes for link and invlink methods when using thread safe varinfo
torfjelde Oct 13, 2023
dc9ad94
attempt at fixing docs
torfjelde Oct 13, 2023
2f320e6
fixed missing test coverage
torfjelde Oct 14, 2023
d3a9b56
formatting
torfjelde Oct 14, 2023
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -48,6 +48,7 @@ export AbstractVarInfo,
    SimpleVarInfo,
    push!!,
    empty!!,
    subset,
    getlogp,
    setlogp!!,
    acclogp!!,
161 changes: 161 additions & 0 deletions src/abstract_varinfo.jl
@@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
    bijector::F
end

"""
    merge_transformations(transformation_left, transformation_right)

Merge two transformations.

The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
"""
function merge_transformations(::NoTransformation, ::NoTransformation)
    return NoTransformation()
end
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
    return DynamicTransformation()
end
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
    return StaticTransformation(merge_bijectors(left.bijector, right.bijector))
end

function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
    return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs))
end

"""
    default_transformation(model::Model[, vi::AbstractVarInfo])
@@ -337,6 +358,146 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
    return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
    subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})

[Review comment on the line above, from a Member]
If we would not already have so many `getindex` methods, I would have thought that `getindex` would be a natural name for this function. But maybe it's still an option?

Then we could have `getindex(::AbstractVarInfo, ::AbstractVector{<:VarName}) -> AbstractVarInfo` and `getindex(::T, ::VarName) -> typeof_varname_variate`, similar to `[1,2,3][[1,3]] = [1, 3]` and `[1,2,3][2] = 2`.

[Reply from the Member Author]
I'd really like this yes, but I also really don't want to touch `getindex` in this codebase 😅

Happy to make this a long-term goal or something though!
[End of review thread; the docstring continues below]

Subset a `varinfo` to only contain the variables `vns`.

!!! warning
    The ordering of the variables in the resulting `varinfo` is _not_
    guaranteed to follow the ordering of the variables in `varinfo`.
    Hence care must be taken, in particular when used in conjunction with
    other methods which use the vector-representation of the `varinfo`,
    e.g. `getindex(varinfo, sampler)`.

# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x = Vector{Float64}(undef, 2)
           x[1] ~ Normal(m, sqrt(s))
           x[2] ~ Normal(m, sqrt(s))
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
 s
 m
 x[1]
 x[2]

julia> for (i, vn) in enumerate(keys(varinfo))
           varinfo[vn] = i
       end

julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0

julia> # Extract one with only `m`.
       varinfo_subset1 = subset(varinfo, [@varname(m),]);

julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
 m

julia> varinfo_subset1[@varname(m)]
2.0

julia> # Extract one with both `s` and `x[2]`.
       varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);

julia> keys(varinfo_subset2)
2-element Vector{VarName}:
 s
 x[2]

julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
 1.0
 4.0
```

`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)

```jldoctest varinfo-subset
julia> # Merge the two.
       varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);

julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
 m
 s
 x[2]

julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
 1.0
 2.0
 4.0

julia> # Merge the two with the original.
       varinfo_merged = merge(varinfo, varinfo_subset_merged);

julia> keys(varinfo_merged)
4-element Vector{VarName}:
 s
 m
 x[1]
 x[2]

julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0
```

# Notes

## Type-stability

!!! warning
    This function is only type-stable when `vns` contains only varnames
    with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will
    be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset end

"""
    merge(varinfo, other_varinfos...)

Merge varinfos into one, giving precedence to the right-most varinfo when sensible.

This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).

See docstring of [`subset(varinfo, vns)`](@ref) for examples.
"""
Base.merge(varinfo::AbstractVarInfo) = varinfo
# Define 3-argument version so 2-argument version will error if not implemented.
function Base.merge(
    varinfo1::AbstractVarInfo,
    varinfo2::AbstractVarInfo,
    varinfo3::AbstractVarInfo,
    varinfo_others::AbstractVarInfo...,
)
    return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...)
end

# Transformations
"""
    istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
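
Editor's sketch: the dispatch pattern in src/abstract_varinfo.jl (identity for one argument, a left-to-right fold for three or more, and no generic two-argument method) can be tried on a toy type. `AbstractBag`, `Bag`, and `EmptyBag` are hypothetical stand-ins, not DynamicPPL types; the only point is that a subtype without its own two-argument `merge` should fail with a `MethodError` rather than recurse.

```julia
# Hypothetical toy hierarchy standing in for `AbstractVarInfo` and a subtype.
abstract type AbstractBag end

struct Bag <: AbstractBag
    values::Dict{Symbol,Float64}
end

# A subtype that deliberately does not implement two-argument `merge`.
struct EmptyBag <: AbstractBag end

# One argument: nothing to merge.
Base.merge(b::AbstractBag) = b

# Three or more arguments: fold left to right through the two-argument method.
function Base.merge(
    b1::AbstractBag, b2::AbstractBag, b3::AbstractBag, others::AbstractBag...
)
    return merge(merge(b1, b2), b3, others...)
end

# Two arguments: supplied per concrete type; the right-most value wins.
Base.merge(left::Bag, right::Bag) = Bag(merge(left.values, right.values))

a = Bag(Dict(:s => 1.0, :m => 2.0))
b = Bag(Dict(:m => 20.0))
c = Bag(Dict(:x => 3.0))
abc = merge(a, b, c)

# Without a two-argument method, the call errors instead of looping.
missing_method = try
    merge(EmptyBag(), EmptyBag())
    false
catch err
    err isa MethodError
end
```

Here `abc.values` holds `:s`, `:m`, and `:x`, with `:m` taken from `b` because it sits to the right of `a`.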
45 changes: 45 additions & 0 deletions src/simple_varinfo.jl
@@ -419,6 +419,51 @@ function Base.eltype(
    return V
end

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
    return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
    # NOTE: This requires `vns` to be explicitly present in `x`.
    if any(!Base.Fix1(haskey, x), vns)
        throw(
            ArgumentError(
                "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
                "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
                "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
            ),
        )
    end
    C = ConstructionBase.constructorof(typeof(x))
    return C(vn => x[vn] for vn in vns)
end

function _subset(x::NamedTuple, vns)
    # NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
    if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns)
        throw(
            ArgumentError(
                "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
                "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
            ),
        )
    end

    syms = map(getsym, vns)
    return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))
end

# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
    values = merge(varinfo_left.values, varinfo_right.values)
    logp = getlogp(varinfo_right)
    transformation = merge_transformations(
        varinfo_left.transformation, varinfo_right.transformation
    )
    return SimpleVarInfo(values, logp, transformation)
end

# Context implementations
# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.
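
Editor's sketch: the `NamedTuple` branch of `_subset` and the value handling in `merge` for `SimpleVarInfo` both rest on Base behaviour, which can be exercised without DynamicPPL. `subset_namedtuple` is a hypothetical helper that takes plain symbols where the PR takes `VarName`s.

```julia
# Hypothetical helper: keep only the fields of `x` named in `syms`.
function subset_namedtuple(x::NamedTuple, syms::AbstractVector{Symbol})
    return NamedTuple{Tuple(syms)}(Tuple(map(sym -> getindex(x, sym), syms)))
end

vals = (s = 1.0, m = 2.0, x = [3.0, 4.0])

# Only whole fields can be selected; there is no way to pick `x[1]` alone,
# which is why the PR rejects non-identity lenses for `NamedTuple` storage.
sub = subset_namedtuple(vals, [:s, :x])

# Base `merge` on `NamedTuple`s: the right-most value wins, while a shared
# field keeps the position it had in the left argument.
merged = merge(vals, (m = 20.0,))
```

Note that `merge` for `SimpleVarInfo` takes `logp` from the right-hand varinfo only, so the merged log-probability is not a sum over both inputs.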
17 changes: 14 additions & 3 deletions src/test_utils.jl
@@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal
end

"""
-    setup_varinfos(model::Model, example_values::NamedTuple, varnames)
+    setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)

Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.

If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
-function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
+function setup_varinfos(
+    model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
+)
    # VarInfo
    vi_untyped = VarInfo()
    model(vi_untyped)
@@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
    svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))

    lp = getlogp(vi_typed)
-    return map((
+    varinfos = map((
        vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
    )) do vi
        # Set them all to the same values.
        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
    end

    if include_threadsafe
        varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...)
    end

    return varinfos
end

"""
56 changes: 52 additions & 4 deletions src/threadsafe.jl
@@ -84,25 +84,56 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl
function link!!(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
-    return link!!(t, vi.varinfo, spl, model)
+    return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
end

function invlink!!(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
-    return invlink!!(t, vi.varinfo, spl, model)
+    return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
end

function link(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
-    return link(t, vi.varinfo, spl, model)
+    return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model)
end

function invlink(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
-    return invlink(t, vi.varinfo, spl, model)
+    return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
    ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    return settrans!!(
        last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
        NoTransformation(),
    )
end

function link(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    return link!!(t, deepcopy(vi), spl, model)
end

function invlink(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    return invlink!!(t, deepcopy(vi), spl, model)
end

function maybe_invlink_before_eval!!(
@@ -192,3 +223,20 @@ istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)

function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
    return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x)
end
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
    return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x)
end

function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
    return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns)
end

function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo)
    return Setfield.@set varinfo_left.varinfo = merge(
        varinfo_left.varinfo, varinfo_right.varinfo
    )
end
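
Editor's sketch: the `link`/`invlink` changes in src/threadsafe.jl share one shape. A method that forwarded to the wrapped varinfo used to return the inner object, which dropped the thread-safe wrapper; the fix rebuilds the wrapper around the updated inner value. The toy below uses a hypothetical `Wrapper` type and a plain constructor call where the PR uses `Setfield.@set`.

```julia
# Hypothetical wrapper standing in for `ThreadSafeVarInfo`.
struct Wrapper{V}
    inner::V
    logps::Vector{Float64}
end

# Some operation defined on the inner value.
scale(x::Vector{Float64}) = 2 .* x

# Before: forwards to the inner value and returns it, losing the wrapper.
scale_unwrapping(w::Wrapper) = scale(w.inner)

# After: rebuild the wrapper around the updated inner value.
scale(w::Wrapper) = Wrapper(scale(w.inner), w.logps)

w = Wrapper([1.0, 2.0], [0.0])
unwrapped = scale_unwrapping(w)  # a plain Vector{Float64}
rewrapped = scale(w)             # still a Wrapper
```

Keeping the return type stable matters for callers that go on to use wrapper-only fields, such as `logps` here.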