Are models always "constant" (from DI's perspective)? #856
Hmmmm. I don't think it's possible, because the things in a model that we want to differentiate with respect to are on the left-hand side of tilde statements, like this:

```julia
using Turing

@model function f()
    x ~ Normal()
    y ~ Normal(x)
end

model = f()
```

However, models cannot capture values of tilde-lhs's from an outside scope (the macro ensures this). The only way to embed info about the value of a tilde-lhs in the model is to do something like this, which stores the value of `x`:

```julia
cmodel = model | (x=1,)
```

What this really does is to marginalise out `x`. I've been trying for a while to come up with pathological ways of getting round these restrictions, but I couldn't really think of any way to do it. Happy to hear if you had some ideas on how to break Turing code though 😄
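To make the effect concrete, here is a minimal sketch (assuming DynamicPPL's conditioning syntax and the LogDensityProblems interface that `LogDensityFunction` implements): after conditioning, `x` no longer counts as a free parameter, so the log density only varies with `y`.

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model function f()
    x ~ Normal()
    y ~ Normal(x)
end

model = f()
cmodel = model | (x=1.0,)

# The unconditioned model has two free parameters; the conditioned one only has `y`.
LogDensityProblems.dimension(LogDensityFunction(model))   # 2
LogDensityProblems.dimension(LogDensityFunction(cmodel))  # 1
```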
Hmmm, I realise I was only thinking of whether the model could be an active argument. I'll ponder the cache a bit as well.
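For reference, here is a rough sketch of what passing the model as a non-active argument means in DifferentiationInterface terms (a generic illustration, not DynamicPPL's actual code): extra arguments wrapped in DI's `Constant` context are treated as fixed data that the backend does not differentiate through.

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

# Toy stand-in for "log density of params, given some fixed model data".
logdensity(params, model_data) = -sum(abs2, params .- model_data) / 2

backend = AutoForwardDiff()
params = [1.0, 2.0]

# `Constant` marks the second argument as inactive: only `params` is
# differentiated; the "model" is treated as constant.
grad = gradient(logdensity, backend, params, Constant([0.5, -0.5]))
```

The question in this issue is essentially whether that constant-model assumption can ever be violated.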
😬 Something like this, perhaps? The model evaluator function (`model.f`) is a closure, so it can capture and mutate state from an enclosing scope:

```julia
julia> using DynamicPPL, Distributions, LogDensityProblems

julia> function g()
           d = [1.0]
           @model function g2()
               x ~ Normal()
               d[1] = x
               y ~ Normal(d[1])
           end
           return g2()
       end
g (generic function with 2 methods)

julia> model = g(); ldf = LogDensityFunction(model);

julia> LogDensityProblems.logdensity(ldf, randn(2))
-2.825520065854062

julia> ldf.model.f.d
1-element Vector{Float64}:
 1.0337255358589155

julia> LogDensityProblems.logdensity(ldf, randn(2))
-2.645539732162322

julia> ldf.model.f.d
1-element Vector{Float64}:
 1.1067816035197735
```
Yeah exactly, something like your last example should give the wrong result when used with Enzyme.
I gave it a spin and it seems to be fine:

```julia
using DynamicPPL, Distributions, LogDensityProblems, ADTypes
using Mooncake, Enzyme

function g()
    d = [1.0]
    @model function g2()
        x ~ Normal()
        # see note below
        d[1] = x
        y ~ Normal(d[1])
    end
    return g2()
end

model = g()
params = [1.1777037635634742, -3.005045711130534]

for adtype in [
    AutoMooncake(; config=nothing),
    AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
    AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
]
    ldf = LogDensityFunction(model; adtype=adtype)
    logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
    @show logp, grad
end

# Note: ForwardDiff and ReverseDiff both don't like the line `d[1] = x` because
# `x` is a tracked type and it can't be reassigned into a Vector{Float}. However,
# the model
#     @model function h()
#         x ~ Normal()
#         y ~ Normal(x)
#     end
# is the same as `g2`, and running ForwardDiff and ReverseDiff on that
# simplified model gives the same results for `(logp, grad)` as below.

# Mooncake
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])
# Enzyme forward
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])
# Enzyme reverse
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])
```

Do you know if that's related to the
I must admit I don't understand why this works with Enzyme.
@penelopeysm I took some time to come up with a minimum working example without DI or Turing, and I think you might have hit a bug in Enzyme with your demo? As in, it happens to work but it really shouldn't: EnzymeAD/Enzyme.jl#2344. In general, annotating models as…
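For illustration, here is a hypothetical minimal sketch of the general shape of the problem (not the actual MWE from EnzymeAD/Enzyme.jl#2344): a closure that mutates state it captured, while the closure itself is annotated as `Const`.

```julia
using Enzyme

# A closure that writes into the array it captured; its output depends on
# that captured state.
function make_f()
    d = [0.0]
    function f(x)
        d[1] = x
        return x * d[1]
    end
    return f
end

f = make_f()

# Wrapping the closure in `Const` tells Enzyme that the captured `d` carries
# no derivative information, even though it is written to and read back
# during the call. Runtime activity is enabled here, mirroring the settings
# used above; whether the resulting gradient can be trusted is exactly what
# the linked issue probes.
res = Enzyme.autodiff(set_runtime_activity(Reverse, true), Const(f), Active, Active(2.0))
@show res
```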
It turns out this is actually an LLVM quirk 🤣
(See DynamicPPL.jl/src/logdensityfunction.jl, line 135 in 0810e14.)
@penelopeysm I just had an afterthought about this: can it happen that users store differentiable data in a model, for instance if it closes over some cache?