| 
1 | 1 | # assume  | 
2 |  | -function tilde_assume(context::AbstractContext, args...)  | 
3 |  | -    return tilde_assume(childcontext(context), args...)  | 
 | 2 | +function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi)  | 
 | 3 | +    return tilde_assume!!(childcontext(context), right, vn, vi)  | 
4 | 4 | end  | 
5 |  | -function tilde_assume(::DefaultContext, right, vn, vi)  | 
 | 5 | +function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi)  | 
6 | 6 |     y = getindex_internal(vi, vn)  | 
7 | 7 |     f = from_maybe_linked_internal_transform(vi, vn, right)  | 
8 | 8 |     x, inv_logjac = with_logabsdet_jacobian(f, y)  | 
9 | 9 |     vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)  | 
10 | 10 |     return x, vi  | 
11 | 11 | end  | 
12 |  | -function tilde_assume(context::PrefixContext, right, vn, vi)  | 
 | 12 | +function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)  | 
13 | 13 |     # Note that we can't use something like this here:  | 
14 | 14 |     #     new_vn = prefix(context, vn)  | 
15 |  | -    #     return tilde_assume(childcontext(context), right, new_vn, vi)  | 
 | 15 | +    #     return tilde_assume!!(childcontext(context), right, new_vn, vi)  | 
16 | 16 |     # This is because `prefix` applies _all_ prefixes in a given context to a  | 
17 | 17 |     # variable name. Thus, if we had two levels of nested prefixes e.g.  | 
18 | 18 |     # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the  | 
19 | 19 |     # first call would apply the prefix `a.b._`, and the recursive call  | 
20 | 20 |     # would apply the prefix `b._`, resulting in `b.a.b._`.  | 
21 | 21 |     # This is why we need a special function, `prefix_and_strip_contexts`.  | 
22 | 22 |     new_vn, new_context = prefix_and_strip_contexts(context, vn)  | 
23 |  | -    return tilde_assume(new_context, right, new_vn, vi)  | 
 | 23 | +    return tilde_assume!!(new_context, right, new_vn, vi)  | 
24 | 24 | end  | 
25 | 25 | 
 
  | 
26 | 26 | """  | 
27 | 27 |     tilde_assume!!(context, right, vn, vi)  | 
28 | 28 | 
  | 
29 | 29 | Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),  | 
30 | 30 | accumulate the log probability, and return the sampled value and updated `vi`.  | 
31 |  | -
  | 
32 |  | -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log  | 
33 |  | -probability of `vi` with the returned value.  | 
34 | 31 | """  | 
35 |  | -function tilde_assume!!(context, right, vn, vi)  | 
36 |  | -    return if right isa DynamicPPL.Submodel  | 
37 |  | -        _evaluate!!(right, vi, context, vn)  | 
38 |  | -    else  | 
39 |  | -        tilde_assume(context, right, vn, vi)  | 
40 |  | -    end  | 
 | 32 | +function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)  | 
 | 33 | +    return _evaluate!!(right, vi, context, vn)  | 
41 | 34 | end  | 
42 | 35 | 
 
  | 
43 | 36 | # observe  | 
 | 
0 commit comments