AdvancedPS v0.7 (and thus Libtask v0.9) support #2585
The tests that I had the patience to run locally now pass. Waiting for the AdvancedPS release to be able to run the full test suite on CI.

Some indicators of speed:

julia> module MWE
       using Turing
       @model function gdemo(x, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x ~ Normal(m, sqrt(s))
           y ~ Normal(m, sqrt(s))
           return s, m
       end
       @time chn = sample(gdemo(2.5, 1.0), PG(10), 10_000)
       describe(chn)
       end

On main:
On this branch:
julia> module MWE
       using Turing
       @model function f(dim=20, ::Type{T}=Float64) where T
           s = Vector{Bool}(undef, dim)
           x = Vector{T}(undef, dim)
           for i in 1:dim
               s[i] ~ Bernoulli()
               if s[i]
                   x[i] ~ Normal()
               else
                   x[i] ~ Beta()
               end
               0.0 ~ Normal(x[i])
           end
           return nothing
       end
       alg = Gibbs(
           @varname(s)=>PG(10),
           @varname(x)=>HMC(0.1, 5),
       )
       @time chn = sample(f(), alg, 1_000)
       end

On main:
On this branch:
Obviously the speed gains are all due to @willtebbutt's fantastic work on Libtask; everything else is just wrapping that work.
@@ -402,11 +391,11 @@ end

 function trace_local_varinfo_maybe(varinfo)
     try
         trace = AdvancedPS.current_trace()
         return trace.model.f.varinfo
         trace = Libtask.get_taped_globals(Any).other
If we change `Libtask.get_taped_globals` to return `nothing` if not inside a running `TapedTask`, the following `try .. catch ... end` can be removed.
Libtask would then lose a distinction between having `nothing` as the taped global within a task, and just not being within a task at all. I wonder if that distinction could be useful in some situations.
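To make the trade-off concrete, here is a toy sketch of that distinction. It does not use Libtask at all: `NotInTask`, `lookup_globals`, `with_globals`, and the `:taped_globals` key are hypothetical stand-ins built on Base's task-local storage, and the sketch only illustrates how a dedicated sentinel keeps "no context at all" apart from "the stored value is `nothing`".

```julia
# Toy model of the design question; none of these names are Libtask API.
struct NotInTask end  # hypothetical sentinel meaning "no surrounding context"

# Look up the stored value, falling back to the sentinel instead of throwing.
lookup_globals() = get(task_local_storage(), :taped_globals, NotInTask())

# Run `f` with `value` temporarily stored as the (toy) taped global.
with_globals(f, value) = task_local_storage(f, :taped_globals, value)

outside = lookup_globals()                               # no context set
inside_nothing = with_globals(lookup_globals, nothing)   # context explicitly `nothing`
inside_value = with_globals(lookup_globals, 42)          # context holds a value

# With a plain `nothing` fallback, `outside` and `inside_nothing` would be
# indistinguishable; the sentinel keeps the two cases apart.
println((outside, inside_nothing, inside_value))
```

Under this assumption, callers could branch on `x isa NotInTask` rather than wrapping the lookup in `try ... catch`, at the cost of one extra type to handle.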
@@ -416,11 +405,10 @@ end

 function trace_local_rng_maybe(rng::Random.AbstractRNG)
     try
         trace = AdvancedPS.current_trace()
         return trace.rng
         return Libtask.get_taped_globals(Any).rng
Same as above.
Codecov Report

Additional details and impacted files

@@             Coverage Diff             @@
##             main    #2585       +/-   ##
===========================================
+ Coverage   58.50%   81.44%   +22.94%
===========================================
  Files          22       22
  Lines        1458     1466        +8
===========================================
+ Hits          853     1194      +341
+ Misses        605      272      -333
Pull Request Test Coverage Report for Build 16350159377 (Coveralls)
Is this reviewable? The tests are failing, there's a method ambiguity that Aqua complains about, there's a Gibbs failure on 1.12 which should be disabled with
I don't want to speak for @mhauru in his absence, but last time we spoke about this PR, it was clear that there were still a few gaps to bridge. If I were to review it at this stage, my sole comment would be to fix the tests.
The remaining test failures on v1.10 and v1.11 should be fixed once TuringLang/Libtask.jl#192 and TuringLang/Libtask.jl#191 are merged and released (I see the tests passing locally). Will then have a look at what's going on with v1.12, and check code quality, and hopefully this could then be done.
Co-authored-by: Hong Ge <[email protected]>
Seems I spoke too soon. All tests now pass on some version of Julia. But the beta binomial test still fails on v1.11, whereas the ESS test fails on v1.10. I'll try to reproduce locally, although no immediate success.
function AdvancedPS.update_rng!(
    trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
)
    # Extract the `args`.
    args = trace.model.ctask.args
    # From `args`, extract the `SamplingContext`, which contains the RNG.
    sampling_context = args[3]
    rng = sampling_context.rng
    trace.rng = rng
    return trace
end
This was used to keep the internal state of `trace` consistent, and should now be taken care of in AdvancedPS.
For my edification: AdvancedPS uses its own RNG, right? Is that distinct from the RNG in the SamplingContext (that was used here) / how are they related?
Don't know, but I don't think they are always the same. AdvancedPS does all sorts of seed splitting stuff to make sure each trace has its own source of randomness. For this purpose it also uses a particular RNG type from `Random123`. Frankly not sure how that will relate to the one in the context.
@@ -481,6 +469,25 @@ function AdvancedPS.Trace(

     tmodel = TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
     AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
This was about keeping the internal state of `newtrace` consistent, and is now dealt with in the `AdvancedPS.Trace` constructor.
@penelopeysm, this is ready for attention now. The only test failures are Libtask v1.12 issues.
Regarding correctness: I have tested quite extensively using StableRNGs on both 1.10.10 and 1.11.6 and this PR makes no difference to the sampled values. I can't figure out why the tests needed increased sample counts, but I at least feel reasonably confident that it's not because of a correctness problem.
Happy to merge whenever you like, just had one question about the rng which is more for my understanding, and it shouldn't block this PR.
Thanks for testing that, I find that reassuring. I'll try to see if Libtask on v1.12 is an easy fix, and then either get that working or merge this as-is.

Waaaait, what, one of the v1.10 CI runs just segfaulted. Ewww.
Passed after rerunning CI. Do we have a nondeterministic test, or did we get unlucky and GHA bugged out?
I would be happy to ignore it, unless and until it shows up again.
The complement PR of TuringLang/AdvancedPS.jl#114, which adds support for the newly rewritten Libtask.
Work in progress, currently blocked by TuringLang/Libtask.jl#186