Skip to content

Commit c38e65f

Browse files
torfjeldegithub-actions[bot]mhaurusunxd3
authored
Attempt at implementation of VarNamedVector (Metadata alternative) (#555)
* initial implementation of VarNameVector * added some hacky getval and getdist get things to work for VarInfo * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added arbitrary metadata field as discussed * renamed idcs to varname_to_index * renamed vns to varnames for VarNameVector * added keys impl for Metadata * added push! and update! for VarNameVector * added getindex_raw! and setindex_raw! for VarNameVector * added `iterate` and `convert` (for `AbstractDict) impls for `VarNameVector` * make the key and eltype part of the `VarNameVector` type * added more tests for VarNameVector * formatting * more testing for VarNameVector * minor changes to some comments * added a bunch more tests for VarNameVector + several bugfixes in the process * formatting * added `similar` implementation for `VarNameVector` * formatting * removed debug statement * made VarInfo slighly more generic wrt. underlying metadata * fixed incorrect behavior in `keys` for `Metadata` * minor style changes to VarNameVector tests * style * added testing of `update!` with smaller sizes and fixed bug related to this * formatting * move functionality related to `push!` for `VarNameVector` into `push!` * Update src/varnamevector.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * several fixes to make sampling with VarNameVector + initiall tests for sampling with VarNameVector * VarInfo + VarNameVector tests for all demo models * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added docs on the design of `VarNameVector` * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added note on `update!` * further elaboration of the design of `VarInfo` and `VarNameVector` * more writing improvements * added docstring to `has_inactive_ranges` and `inactive_ranges_sweep!` * moved docs on `VarInfo` design to a separate internals section * writing improvements for internal docs * further motivation of the design choices made in `VarNameVector` * improved writing * VarNameVector is now grown as much as needed * updated `delete!` * Significant changes to implementation of `VarNameVector`: - "delete-by-mark" is now replaced by proper deletion. - `inactive_ranges` replaced by `num_inactive`, which only keeps track of the number of inactive entries for a given `VarName. - `VarNameVector` is now a "grow-as-needed" structure where the underlying also mimics the order that the user experiences.` * added `copy` when constructing `VectorVarInfo` from `VarInfo` * added missing `isempty` impl * remove impl of `iterate` and instead implemented `pairs` and `values` iterators * added missing `empty!` for `num_inactive` * removed redundant `shift_left!` methd * fixed `delete!` for `VarNameVector` * added `is_contiguous` as an alterantive to `!has_inactive` * updates to internal docs * renamed `sweep_inactive_ranges!` to `contiguify!` * improvements to internal docs * more improvements to internal docs * moved additional methods description in internals to earlier in the doc * moved internals docs to a separate directory and split into files * more improvements to internals doc * formatting * added tests for `delete!` and fixed reference to old method * addition to `delete!` test * added `values_as` impls for `VarNameVector` * added docs for `replace_valus` and `values_as` for `VarNameVector` * fixed doctest * formatting * temporarily disable doctests so we can build docs * added missing compat entry for ForwardDiff in docs * moved some shared code into methods to make things a bit cleaner * added impl of `merge` for `VarNameVector` * renamed a few variables in `merge` impl for `VarNameVector` * forgot to include some changes in previous commit * added impl of `subset` for `VarNameVector` * fixed `pairs` impl for `VarNameVector` * added missing impl of `subset` for `VectorVarInfo` * added missing impl of `merge_metadata` for `VarNameVector` * added a bunch of `from_vec_transform` and `tovec` impls to make `VarNameVector` work with `Cholesky`, etc. * make default args use `from_vec_transform` rather than `FromVec` * fixed `values_as` fro `VarInfo` with `VarNameVector` as `metadata` * fixed impl of `getindex_raw` when using integer index for `VarNameVector` * added tests for `getindex` with `Int` index for `VarNameVector` * fix for `setindex!` and `setindex_raw!` for `VarNameVector` * introduction of `from_vec_transform` and `tovec` and its usage in `VarInfo` * moved definition of `is_splat_symbol` to the file where it's used * added `VarInfo` constructor with vector input for `VectorVarInfo` * make `extract_priors` take the `rng` as an argument * added `replace_values` for `Metadata` * make link and invlink act on the `metadata` field for `VarInfo` + implementations of these for `Metadata` and `VarNameVector` * added temporary defs of `with_logabsdet_jacobian` and `inverse` for `transpose` and `Bijectors.vec_to_triu` * added invlink_with_logpdf overload for `ThreadSafeVarInfo` * added `is_transformed` field to `VarNameVector` * removed unnecessary defintions of `with_logabsdet_jacobian` and `inverse` for `transpose` * fixed issue where we were storing the wrong transformations in `VarNameVector` * make sure `extract_priors` doesn't mutate the `varinfo` * updated `similar` for `VarNameVector` and fixed `invlink` for `VarNameVector` * added handling of `is_transformed` in `merge` for `VarNameVector` * removed unnecesasry `deepcopy` from outer `link` * updated `push!` to also `push!` on `is_transformed` * skip tests for mutating linking when using VarNameVector * use same projection for `Cholesky` in `VarNameVector` as in `VarInfo` * fixed `settrans!!` for `VarInfo` with `VarNameVector` * fixed bug in `set_flag!` * fixed another typo * fixed return values of `settrans!!` * updated static transformation tests * Update test/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * removed unnecessary impl of `extract_priors` * make `short_varinfo_name` of `TypedVarInfo` a bit more informative * moved impl of `has_varnamevector` for `ThreadSafeVarInfo` * added back `extract_priors` impl as we do need it * forgot to include tests for `VarNameVector` in `runtests.jl` * fix for `relax_container_types` in `test/varnamevector.jl` * fixed `need_transforms_relaxation` * updated some tests to not refer directly to `FromVec` * introduce `from_internal_transform` and its siblings * remove `with_logabsdet_jacobian_and_reconstruct` in favour of `with_logabsdet_jacobian` with `from_linked_internal_transform`, etc. * added `internal_to_linked_internal_transform` + fixed a few bugs in the linking as a resultt * added `linked_internal_to_internal_transform` as a complement to `interanl_to_linked_interanl_transform` * fixed bugs in `invlink` for `VarInfo` using `linked_internal_to_internal_transform` * more work on removing calls to `reconstruct` * removed redundant comment * added `from_linked_vec_transform` specialization for `LKJ` * more work on removing references to `reconstruct` * added `copy` in `values_from_metadata` to preserve behavior and avoid refs to internal representation * remove `reconstruct_and_link` and `invlink_and_reconstruct` * replaced references to `link_and_reconstruct` and `invlink_and_reconstruct` * introduced `recombine` and replaced calls to `reconstruct` with `n` samples * completely removed `reconstruct` * renamed `maybe_reconstruct_and_link` to `to_maybe_linked_internal` and `maybe_invlink_and_reconstruct` to `from_maybe_linked_internal` * added impls of `from_*_internal_transform` for `ThreadSafeVarInfo` * removed `reconstruct` from docs and from exports * renamed `getval` to `getindex_internal` and made `dist` an optional argument for all the transform-related methods * updated docs + added description of how internals of transforms work * added a bunch of illustrations for the transforms docs + dot files used to generated * temporarily removed `VarNameVector` completely * formatting * Update docs/src/internals/transformations.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update docs/src/internals/transformations.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * removed refs to VectorVarInfo * added impls of `from_internal_transform` for `ThreadSafeVarInfo` * reverted accidental removal of old `VarInfo` constructor * fixed incorrect `recombine` call * removed undefined refs to `VarNameVector` stuff in `setup_varinfos` * bump minior version because Turing breaks * fix: was using `from_linked_internal_transform` in `from_internal_transform` for `ThreadSafeVarInfo` * removed `getindex_raw` * removed redundant docstrings * fixed tests * fixed comparisons in tests * try relative references for images in transformation docs * another attempt at fixing asset-references * fixed getindex diagrams in docs * minor changes to comments * remove Combinatorics as a test dep, as it's not needed for this PR * reverted unnecessary change * disable type-stability tests for models on older Julia versions * removed seemingly completely unused impl of `setval!` * Revert "temporarily removed `VarNameVector` completely" This reverts commit 95dc8e3. * Revert "remove Combinatorics as a test dep, as it's not needed for this PR" This reverts commit 071bebf. * More work on `VarNameVector` (#637) * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Markus Hauru <[email protected]> * Type-stability tests are now correctly using `rand_prior_true` instead of `rand` * `getindex_internal` now calls `getindex` instead of `view`, as the latter can result in type-instability since transformed variables typically result in non-view even if input is a view * Removed seemingly unnecessary definition of `getindex_internal` * Fixed references to `newmetadata` which has been replaced by `replace_values` * Made implementation of `recombine` more explicit * Added docstrings for `untyped_varinfo` and `typed_varinfo` * Added TODO comment about implementing `view` for `VarInfo` * Fixed potential infinite recursion as suggested by @mhauru * added docstring to `from_vec_trnasform_for_size * Replaced references to `vectorize(dist, x)` with `tovec(x)` * Fixed docstring * Update src/extract_priors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump minor version since this is a breaking change * Apply suggestions from code review Co-authored-by: Markus Hauru <[email protected]> * Update src/varinfo.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Apply suggestions from code review * Apply suggestions from code review * Update src/extract_priors.jl Co-authored-by: Xianda Sun <[email protected]> * Added fix for product distributions of targets with changing support + tests * Addeed tests for product of distributions with dynamic support * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix typos, improve docstrings * Use Accessors rather than Setfield * Simplify group_by_symbol * Add short_varinfo_name(::VectorVarInfo) * Add tests for subset * Export VectorVarInfo * Tighter type bound for has_varnamevector * Add some VectorVarName methods * Add todo notes, remove dead code, fix a typo. * Bug fixes and small improvements * VarNameVector improvements * Improve generated_quantities and its tests * Improvement to VarNameVector * Fix a test to work with VectorVarName * Fix generated_quantities * Fix type stability issues * Various VarNameVector fixes and improvements * Bump version number * Improvements to generated_quantities * Code formatting * Code style * Add fallback implementation of findinds for VarNameVector * Rename VarNameVector to VarNamedVector * More renaming of VNV. Remove unused VarNamedVector.metadata field. * Rename FromVec to ReshapeTransform * Progress towards having VarNamedVector as storage for SimpleVarInfo * Fix unflatten(vnv::VarNamedVector, vals) * More work on SimpleVarInfo{VarNamedVector} * More tests for SimpleVarInfo{VarNamedVector} * More tests for SimpleVarInfo{VarNamedVector} * Respond to review feedback * Add float_type_with_fallback(::Type{Union{}}) * Move some VNV functions to the correct file * Fix push! for VNV * Rename VNV.is_transformed to VNV.is_unconstrained * Improve VNV docstring * Add VNV inner constructor checks * Reorganise parts of VNV code * Documentation and small fixes for VNV * Rename loosen_types!! and tighten_types, add docstrings and doctests * Rename VarNameVector to VarNamedVector in docs * Documentation and small fixes to VNV * Fix subset(::VarNamedVector, args...) for unconstrained variables. --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]> * Bump Bijectors dependecy * Remove dead TODO note * Remove old TODOs, improve VNV invlinking * Fix from_vec_transform for 0-dim arrays * Fix unflatten for VarInfo * Fix some VarInfo index getters * Change how VNV handles transformations, and other VNV stuff * Small docs fixes * Small fixes all over for VNV * Add comments * Fix some tests * Change long string formatting to support Julia 1.6 * Small changes to ReshapeTransformation * Revert unrelated changes to ReverseDiff extension * Improve VarNamedVector VarInfo testing * Fix some depwarns * Improvements to test/simple_varinfo.jl * Fix for unset_flag!, better docstring * Add a comment about hasvalue/getvalue * Add @non_differentiable calls to work around Zygote limitations * Fix docs, workaround Zygote issue * Remove outdated workaround * Move has_varnamedvector(varinfo::VarInfo) to abstract_varinfo.jl * Make copies of logp and num_produce in subset * Rename getindex_raw to getindex_internal * Add push!(::VarNamedVector, ::Pair) * Improve VarNamedVector docs * Simplify VarNamedVector constructors * Change how VNV setindex! et al work * More improvements to VNV setters and their tests * Fix style issues in VNV * Update VNV docs. Add haskey to VarInfo * Fix VarInfo docs * Disable a test that only works for VectorVarInfo * Fix bug in isempty(::TypedVarInfo) * Make some doctests platform independent * Better implementation of haskey(::VarInfo, ::VarName) Co-authored-by: Tor Erlend Fjelde <[email protected]> * Improve haskey for VarInfo * Make a VNV doctest more robust * Remote IndexStyle for VNV * Clean up an old comment * Fix haskey(::VarInfo, ::VarName) * Clarify a TODO note in varinfo.jl * Reintroduce Int indexing to VNV * Stop exporting any VNV stuff * Fix docs * Revert default VarInfo metadata type to Metadata * Fix a few trivial issues with Metadata * Docs bug fix * Fix type constraint --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Xianda Sun <[email protected]>
1 parent 7f91c07 commit c38e65f

25 files changed

+3339
-262
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.29.2"
3+
version = "0.30"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -46,7 +46,7 @@ AbstractMCMC = "5"
4646
AbstractPPL = "0.8.4, 0.9"
4747
Accessors = "0.1"
4848
BangBang = "0.4.1"
49-
Bijectors = "0.13.9"
49+
Bijectors = "0.13.18"
5050
ChainRulesCore = "1"
5151
Compat = "4"
5252
ConstructionBase = "1.5.4"

docs/make.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ makedocs(;
2222
pages=[
2323
"Home" => "index.md",
2424
"API" => "api.md",
25-
"Internals" => ["internals/transformations.md"],
25+
"Internals" => ["internals/varinfo.md", "internals/transformations.md"],
2626
],
2727
checkdocs=:exports,
2828
doctest=false,

docs/src/api.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,19 @@ resetlogp!!
294294
```@docs
295295
keys
296296
getindex
297-
DynamicPPL.getindex_internal
298297
push!!
299298
empty!!
300299
isempty
300+
DynamicPPL.getindex_internal
301+
DynamicPPL.setindex_internal!
302+
DynamicPPL.update_internal!
303+
DynamicPPL.insert_internal!
304+
DynamicPPL.length_internal
305+
DynamicPPL.reset!
306+
DynamicPPL.update!
307+
DynamicPPL.insert!
308+
DynamicPPL.loosen_types!!
309+
DynamicPPL.tighten_types
301310
```
302311

303312
```@docs

docs/src/internals/varinfo.md

+302
Large diffs are not rendered by default.

ext/DynamicPPLChainRulesCoreExt.jl

+2
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
2424
# No need + causes issues for some AD backends, e.g. Zygote.
2525
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x)
2626

27+
ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges)
28+
2729
end # module

ext/DynamicPPLMCMCChainsExt.jl

+137-6
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,152 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
"""
46+
generated_quantities(model::Model, chain::MCMCChains.Chains)
47+
48+
Execute `model` for each of the samples in `chain` and return an array of the values
49+
returned by the `model` for each sample.
50+
51+
# Examples
52+
## General
53+
Often you might have additional quantities computed inside the model that you want to
54+
inspect, e.g.
55+
```julia
56+
@model function demo(x)
57+
# sample and observe
58+
θ ~ Prior()
59+
x ~ Likelihood()
60+
return interesting_quantity(θ, x)
61+
end
62+
m = demo(data)
63+
chain = sample(m, alg, n)
64+
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
65+
# from the posterior/`chain`:
66+
generated_quantities(m, chain) # <= results in a `Vector` of returned values
67+
# from `interesting_quantity(θ, x)`
68+
```
69+
## Concrete (and simple)
70+
```julia
71+
julia> using DynamicPPL, Turing
72+
73+
julia> @model function demo(xs)
74+
s ~ InverseGamma(2, 3)
75+
m_shifted ~ Normal(10, √s)
76+
m = m_shifted - 10
77+
78+
for i in eachindex(xs)
79+
xs[i] ~ Normal(m, √s)
80+
end
81+
82+
return (m, )
83+
end
84+
demo (generic function with 1 method)
85+
86+
julia> model = demo(randn(10));
87+
88+
julia> chain = sample(model, MH(), 10);
89+
90+
julia> generated_quantities(model, chain)
91+
10×1 Array{Tuple{Float64},2}:
92+
(2.1964758025119338,)
93+
(2.1964758025119338,)
94+
(0.09270081916291417,)
95+
(0.09270081916291417,)
96+
(0.09270081916291417,)
97+
(0.09270081916291417,)
98+
(0.09270081916291417,)
99+
(0.043088571494005024,)
100+
(-0.16489786710222099,)
101+
(-0.16489786710222099,)
102+
```
103+
"""
45104
function DynamicPPL.generated_quantities(
46105
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
47106
)
48107
chain = MCMCChains.get_sections(chain_full, :parameters)
49108
varinfo = DynamicPPL.VarInfo(model)
50109
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
51110
return map(iters) do (sample_idx, chain_idx)
52-
# Update the varinfo with the current sample and make variables not present in `chain`
53-
# to be sampled.
54-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
111+
if DynamicPPL.supports_varname_indexing(chain)
112+
varname_pairs = _varname_pairs_with_varname_indexing(
113+
chain, varinfo, sample_idx, chain_idx
114+
)
115+
else
116+
varname_pairs = _varname_pairs_without_varname_indexing(
117+
chain, varinfo, sample_idx, chain_idx
118+
)
119+
end
120+
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
121+
return fixed_model()
122+
end
123+
end
124+
125+
"""
126+
_varname_pairs_with_varname_indexing(
127+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
128+
)
55129
56-
# TODO: Some of the variables can be a view into the `varinfo`, so we need to
57-
# `deepcopy` the `varinfo` before passing it to `model`.
58-
model(deepcopy(varinfo))
130+
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
131+
from the chain.
132+
133+
This implementation assumes `chain` can be indexed using variable names, and is the
134+
preffered implementation.
135+
"""
136+
function _varname_pairs_with_varname_indexing(
137+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
138+
)
139+
vns = DynamicPPL.varnames(chain)
140+
vn_parents = Iterators.map(vns) do vn
141+
# The call nested_setindex_maybe! is used to handle cases where vn is not
142+
# the variable name used in the model, but rather subsumed by one. Except
143+
# for the subsumption part, this could be
144+
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145+
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
146+
DynamicPPL.nested_setindex_maybe!(
147+
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
148+
)
59149
end
150+
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
151+
vn_parent => varinfo[vn_parent]
152+
end
153+
return varname_pairs
154+
end
155+
156+
"""
157+
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
158+
159+
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
160+
won't catch all cases. We should get rid of this if we can.
161+
"""
162+
# TODO(mhauru) See docstring above.
163+
function _vcat_subsumed_values(vn_string, values, key_strings)
164+
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
165+
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
166+
end
167+
168+
"""
169+
_varname_pairs_without_varname_indexing(
170+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
171+
)
172+
173+
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
174+
from the chain.
175+
176+
This implementation does not assume that `chain` can be indexed using variable names. It is
177+
thus not guaranteed to work in cases where the variable names have complex subsumption
178+
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
179+
"""
180+
function _varname_pairs_without_varname_indexing(
181+
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
182+
)
183+
values = chain.value[sample_idx, :, chain_idx]
184+
keys = Base.keys(chain)
185+
keys_strings = map(string, keys)
186+
varname_pairs = [
187+
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
188+
vn in Base.keys(varinfo)
189+
]
190+
return varname_pairs
60191
end
61192

62193
end

src/DynamicPPL.jl

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ include("sampler.jl")
177177
include("varname.jl")
178178
include("distribution_wrappers.jl")
179179
include("contexts.jl")
180+
include("varnamedvector.jl")
180181
include("abstract_varinfo.jl")
181182
include("threadsafe.jl")
182183
include("varinfo.jl")

src/abstract_varinfo.jl

+12-5
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
295295
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
296296
297297
julia> # For the sake of brevity, let's just check the type.
298-
md = values_as(vi); md.s isa DynamicPPL.Metadata
298+
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
299299
true
300300
301301
julia> values_as(vi, NamedTuple)
@@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
321321
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
322322
323323
julia> # For the sake of brevity, let's just check the type.
324-
values_as(vi) isa DynamicPPL.Metadata
324+
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
325325
true
326326
327327
julia> values_as(vi, NamedTuple)
@@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
349349
This should generally not be called explicitly, as it's only used in
350350
[`matchingvalue`](@ref) to determine the default type to use in place of
351351
type-parameters passed to the model.
352-
352+
353353
This method is considered legacy, and is likely to be deprecated in the future.
354354
"""
355355
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
@@ -363,6 +363,13 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
363363
return eltype(T)
364364
end
365365

366+
"""
367+
has_varnamedvector(varinfo::VarInfo)
368+
369+
Returns `true` if `varinfo` uses `VarNamedVector` as metadata.
370+
"""
371+
has_varnamedvector(vi::AbstractVarInfo) = false
372+
366373
# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
367374
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
368375
# might result in a `Vector{Any}`.
@@ -554,7 +561,7 @@ end
554561
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
555562
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
556563
557-
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
564+
Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`.
558565
559566
If `t` is not provided, `default_transformation(model, vi)` will be used.
560567
@@ -573,7 +580,7 @@ end
573580
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
574581
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
575582
576-
Transform the variables in `vi` to their constrained space, using the (inverse of)
583+
Transform the variables in `vi` to their constrained space, using the (inverse of)
577584
transformation `t`, mutating `vi` if possible.
578585
579586
If `t` is not provided, `default_transformation(model, vi)` will be used.

src/context_implementations.jl

+14-3
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,14 @@ function assume(
240240
if haskey(vi, vn)
241241
# Always overwrite the parameters with new ones for `SampleFromUniform`.
242242
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
243-
unset_flag!(vi, vn, "del")
243+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
244+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
245+
# if that's okay.
246+
unset_flag!(vi, vn, "del", true)
244247
r = init(rng, dist, sampler)
245248
f = to_maybe_linked_internal_transform(vi, vn, dist)
249+
# TODO(mhauru) This should probably be call a function called setindex_internal!
250+
# Also, if we use !! we shouldn't ignore the return value.
246251
BangBang.setindex!!(vi, f(r), vn)
247252
setorder!(vi, vn, get_num_produce(vi))
248253
else
@@ -516,7 +521,10 @@ function get_and_set_val!(
516521
if haskey(vi, vns[1])
517522
# Always overwrite the parameters with new ones for `SampleFromUniform`.
518523
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
519-
unset_flag!(vi, vns[1], "del")
524+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
525+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
526+
# that's okay.
527+
unset_flag!(vi, vns[1], "del", true)
520528
r = init(rng, dist, spl, n)
521529
for i in 1:n
522530
vn = vns[i]
@@ -554,7 +562,10 @@ function get_and_set_val!(
554562
if haskey(vi, vns[1])
555563
# Always overwrite the parameters with new ones for `SampleFromUniform`.
556564
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
557-
unset_flag!(vi, vns[1], "del")
565+
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
566+
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
567+
# that's okay.
568+
unset_flag!(vi, vns[1], "del", true)
558569
f = (vn, dist) -> init(rng, dist, spl)
559570
r = f.(vns, dists)
560571
for i in eachindex(vns)

0 commit comments

Comments
 (0)