diff --git a/Project.toml b/Project.toml index fafeec2ba..df7357d3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.1" +version = "0.31.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ec890a674..f7ce569b3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -178,6 +178,12 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<: return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end +vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) +vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) +function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) + return vector_getranges(vi.varinfo, vns) +end + function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) return set_retained_vns_del_by_spl!(vi.varinfo, spl) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cf1f1b02..2c07d4298 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -202,6 +202,15 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) +""" + vector_length(varinfo::VarInfo) + +Return the length of the vector representation of `varinfo`. +""" +vector_length(varinfo::VarInfo) = length(varinfo.metadata) +vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(md::Metadata) = sum(length, md.ranges) + unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) # TODO: deprecate. @@ -626,7 +635,72 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. 
""" function getranges(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) + return map(Base.Fix1(getrange, vi), vns) +end + +""" + vector_getrange(varinfo::VarInfo, varname::VarName) + +Return the range corresponding to `varname` in the vector representation of `varinfo`. +""" +vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) +function vector_getrange(vi::TypedVarInfo, vn::VarName) + offset = 0 + for md in values(vi.metadata) + # First, we need to check if `vn` is in `md`. + # In this case, we can just return the corresponding range + offset. + haskey(md, vn) && return getrange(md, vn) .+ offset + # Otherwise, we need to get the cumulative length of the ranges in `md` + # and add it to the offset. + offset += sum(length, md.ranges) + end + # If we reach this point, `vn` is not in `vi.metadata`. + throw(KeyError(vn)) +end + +""" + vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName}) + +Return the ranges corresponding to `varnames` in the vector representation of `varinfo`. +""" +function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) + return map(Base.Fix1(vector_getrange, varinfo), varname) +end +# Specialized version for `TypedVarInfo`. +function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) + # TODO: Does it help if we _don't_ convert to a vector here? + metadatas = collect(values(varinfo.metadata)) + # Extract the offsets. + offsets = cumsum(map(vector_length, metadatas)) + # Extract the ranges from each metadata. + ranges = Vector{UnitRange{Int}}(undef, length(vns)) + # Need to keep track of which ones we've seen. + not_seen = fill(true, length(vns)) + for (i, metadata) in enumerate(metadatas) + vns_metadata = filter(Base.Fix1(haskey, metadata), vns) + # If none of the variables exist in the metadata, we skip this metadata. + isempty(vns_metadata) && continue + # Otherwise, we extract the ranges. + offset = i == 1 ? 
0 : offsets[i - 1] + for vn in vns_metadata + r_vn = getrange(metadata, vn) + # Get the index, so we return in the same order as `vns`. + # NOTE: There might be duplicates in `vns`, so we need to handle that. + indices = findall(==(vn), vns) + for idx in indices + not_seen[idx] = false + ranges[idx] = r_vn .+ offset + end + end + end + # Raise key error if any of the variables were not found. + if any(not_seen) + inds = findall(not_seen) + # Just use a `convert` to get the same type as the input; don't want to confuse by overly + # specializing the types in the error message. + throw(KeyError(convert(typeof(vns), vns[inds]))) + end + return ranges end """ @@ -1314,13 +1388,13 @@ end function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations - yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) + yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn)) # Determine the new range. - start = first(getrange(vi, vn)) + start = first(getrange(md, vn)) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. - setrange!(vi, vn, start:(start + length(yvec) - 1)) + setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
- setval!(vi, yvec, vn) + setval!(md, yvec, vn) acclogp!!(vi, -logjac) return vi end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index a5097602d..039b549d6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa return replace_raw_storage(vnv, vals) end +vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) + """ unflatten(vnv::VarNamedVector, vals::AbstractVector) diff --git a/test/varinfo.jl b/test/varinfo.jl index c45fb47e0..ae4319904 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -813,4 +813,46 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test DynamicPPL.istrans(varinfo2, vn) end end + + # NOTE: It is not yet clear if this is something we want from all varinfo types. + # Hence, we only test the `VarInfo` types here. + @testset "vector_getranges for `VarInfo`" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + nt = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, nt, vns; include_threadsafe=true + ) + # Only keep `VarInfo` types. + varinfos = filter( + Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + x = values_as(varinfo, Vector) + + # Let's just check all the subsets of `vns`. + @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in + combinations(vns) + ranges = DynamicPPL.vector_getranges(varinfo, vns_subset) + @test length(ranges) == length(vns_subset) + for (r, vn) in zip(ranges, vns_subset) + @test x[r] == DynamicPPL.tovec(varinfo[vn]) + end + end + + # Let's try some failure cases. + @test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[] + # Non-existent variables. 
+ @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}()] + ) + @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] + ) + # Duplicate variables. + ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2)) + @test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2) + end + end + end end