diff --git a/src/varinfo.jl b/src/varinfo.jl index a90b81488..486d24191 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1297,6 +1297,10 @@ function _link_metadata!!( metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) set_transformed!(metadata, true, vn) end + # Linking can often change the sizes of variables, causing inactive elements. We don't + # want to keep them around, since typically linking is done once and then the VarInfo + # is evaluated multiple times. Hence we contiguify here. + metadata = contiguify!(metadata) return metadata, cumulative_logjac end @@ -1465,6 +1469,10 @@ function _invlink_metadata!!( metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) set_transformed!(metadata, false, vn) end + # Linking can often change the sizes of variables, causing inactive elements. We don't + # want to keep them around, since typically linking is done once and then the VarInfo + # is evaluated multiple times. Hence we contiguify here. + metadata = contiguify!(metadata) return metadata, cumulative_inv_logjac end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 2c66e1245..17b851d1d 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -341,10 +341,13 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) vnv_left.num_inactive == vnv_right.num_inactive end -function is_concretely_typed(vnv::VarNamedVector) - return isconcretetype(eltype(vnv.varnames)) && - isconcretetype(eltype(vnv.vals)) && - isconcretetype(eltype(vnv.transforms)) +function is_tightly_typed(vnv::VarNamedVector) + k = eltype(vnv.varnames) + v = eltype(vnv.vals) + t = eltype(vnv.transforms) + return (isconcretetype(k) || k === Union{}) && + (isconcretetype(v) || v === Union{}) && + (isconcretetype(t) || t === Union{}) end getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] @@ -880,7 +883,16 @@ function loosen_types!!( return if vn_type == K && val_type == V && transform_type == T vnv elseif isempty(vnv) - VarNamedVector(vn_type[], val_type[], transform_type[]) + VarNamedVector( + Dict{vn_type,Int}(), + Vector{vn_type}(), + UnitRange{Int}[], + Vector{val_type}(), + Vector{transform_type}(), + BitVector(), + Dict{Int,Int}(); + check_consistency=false, + ) else # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but # then here always revert to Vector. @@ -944,7 +956,7 @@ julia> vnv_tight.transforms ``` """ function tighten_types!!(vnv::VarNamedVector) - return if is_concretely_typed(vnv) + return if is_tightly_typed(vnv) # There can not be anything to tighten, so short-circuit. vnv elseif isempty(vnv) @@ -1020,6 +1032,7 @@ function insert_internal!!( end vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) insert_internal!(vnv, val, vn, transform) + vnv = tighten_types!!(vnv) return vnv end @@ -1029,6 +1042,7 @@ function update_internal!!( transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) + vnv = tighten_types!!(vnv) return vnv end @@ -1104,6 +1118,9 @@ care about them. This is in a sense the reverse operation of `vnv[:]`. +The return value may share memory with the input `vnv`, and thus one can not be mutated +safely without affecting the other. + Unflatten recontiguifies the internal storage, getting rid of any inactive entries. # Examples @@ -1125,15 +1142,20 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) ), ) end - new_ranges = deepcopy(vnv.ranges) - recontiguify_ranges!(new_ranges) + new_ranges = vnv.ranges + num_inactive = vnv.num_inactive + if has_inactive(vnv) + new_ranges = recontiguify_ranges!(new_ranges) + num_inactive = Dict{Int,Int}() + end return VarNamedVector( vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms, - vnv.is_unconstrained; + vnv.is_unconstrained, + num_inactive; check_consistency=false, ) end @@ -1428,6 +1450,9 @@ julia> vnv[@varname(x)] # All the values are still there. ``` """ function contiguify!(vnv::VarNamedVector) + if !has_inactive(vnv) + return vnv + end # Extract the re-contiguified values. # NOTE: We need to do this before we update the ranges. old_vals = copy(vnv.vals) diff --git a/test/Project.toml b/test/Project.toml index de2160f4f..c96087d66 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" @@ -34,6 +35,7 @@ AbstractMCMC = "5" AbstractPPL = "0.13" Accessors = "0.1" Aqua = "0.8" +BangBang = "0.4" Bijectors = "0.15.1" Combinatorics = "1" DifferentiationInterface = "0.6.41, 0.7" diff --git a/test/accumulators.jl b/test/accumulators.jl index e45dfb028..a7175f019 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -117,7 +117,7 @@ using DynamicPPL: @test at_all64[:LogLikelihood] == ll_f64 @test haskey(AccumulatorTuple(lp_f64), Val(:LogPrior)) - @test ~haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood)) + @test !haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood)) @test length(AccumulatorTuple(lp_f64, ll_f64)) == 2 @test keys(at_all64) == (:LogPrior, :LogLikelihood) @test collect(at_all64) == [lp_f64, ll_f64] diff --git a/test/runtests.jl b/test/runtests.jl index b6a3f7bf6..7a9c12525 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ADTypes using DynamicPPL using AbstractMCMC using AbstractPPL +using BangBang: delete!!, setindex!! using Bijectors using DifferentiationInterface using Distributions diff --git a/test/varinfo.jl b/test/varinfo.jl index 6b31fbe91..a1a1b370f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -71,10 +71,10 @@ end r = rand(dist) @test isempty(vi) - @test ~haskey(vi, vn) + @test !haskey(vi, vn) @test !(vn in keys(vi)) vi = push!!(vi, vn, r, dist) - @test ~isempty(vi) + @test !isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -95,7 +95,7 @@ end vi = empty!!(vi) @test isempty(vi) vi = push!!(vi, vn, r, dist) - @test ~isempty(vi) + @test !isempty(vi) end test_base(VarInfo()) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index b764d517b..3a327c147 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -79,7 +79,7 @@ function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) end function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) + varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index) varnames_new = convert(Vector{VarName}, vnv.varnames) else varname_to_index_new = vnv.varname_to_index @@ -517,7 +517,7 @@ end @testset "deterministic" begin n = 5 vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Growing should not create inactive ranges. for i in 1:n @@ -543,7 +543,7 @@ end @testset "random" begin n = 5 vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Insert a bunch of random-length vectors. @@ -579,6 +579,91 @@ end @test is_transformed(vnv, @varname(t[1])) @test subset(vnv, vns) == vnv end + + @testset "loosen and tighten types" begin + """ + test_tightenability(vnv::VarNamedVector) + + Test that tighten_types!! is a no-op on `vnv`. + """ + function test_tightenability(vnv::DynamicPPL.VarNamedVector) + @test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv)) + # TODO(mhauru) We would like to check something more stringent here, namely that + # the operation is compiled to a direct no-op, with no instructions at all. I + # don't know how to do that though, so for now we just check that it doesn't + # allocate. + @allocations(DynamicPPL.tighten_types!!(vnv)) == 0 + return nothing + end + + vn = @varname(a[1]) + # Test that tighten_types!! is a no-op on an empty VarNamedVector. + vnv = DynamicPPL.VarNamedVector() + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + # Also check that it literally returns the same object, and both tighten and loosen + # are type stable. + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise for a VarNamedVector with something pushed into it. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise for a VarNamedVector with abstract element-types, when that is needed for + # the current contents because mixed types have been pushed into it. However, this + # time, since the types are only as tight as they can be, but not actually concrete, + # tighten_types!! can't be type stable. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + vnv = setindex!!(vnv, 2, @varname(b)) + @test !DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise when first mixed types are pushed, but then deleted. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + vnv = setindex!!(vnv, 2, @varname(b)) + @test !DynamicPPL.is_tightly_typed(vnv) + vnv = delete!!(vnv, vn) + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + + # Test that loosen_types!! does really loosen them and that tighten_types!! reverts + # that. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + @test DynamicPPL.is_tightly_typed(vnv) + k = eltype(vnv.varnames) + e = eltype(vnv.vals) + t = eltype(vnv.transforms) + # Loosen key type. + vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t) + @test !DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosen element type + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t) + @test !DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosen transformation type + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function) + @test !DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosening to the same types as currently should do nothing. + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t) + @test DynamicPPL.is_tightly_typed(vnv) + @allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0 + end end @testset "VarInfo + VarNamedVector" begin