Are models always "constant" (from DI's perspective)? #856

Open
gdalle opened this issue Mar 22, 2025 · 8 comments

Comments

@gdalle (Contributor) commented Mar 22, 2025

DI.Constant(model),

@penelopeysm I just had an afterthought about this: can it happen that users store differentiable data in a model, for instance if it closes over some cache?
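
For reference, a minimal sketch of what a Constant context means in DI (the function f below and its arguments are hypothetical, not from this thread): the constant value is passed through to the function, but no derivative is computed with respect to it.

using DifferentiationInterface, ADTypes
import ForwardDiff

f(x, c) = c * sum(abs2, x)   # c is a fixed setting, x is the active input
gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(3.0))  # == [6.0, 12.0]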

@penelopeysm (Member)

Hmmmm. I don't think it's possible, because the things in a model that we want to differentiate with respect to are on the left-hand side of tilde statements, like x and y here:

using Turing

@model function f()
    x ~ Normal()
    y ~ Normal(x)
end

model = f()

However, models cannot capture values of tilde-lhs's from an outside scope (the macro ensures this). The only way to embed info about the value of tilde-lhs's in the model is to do something like this, which stores the value of x = 1 in the model itself:

cmodel = model | (x=1,)

What this really does is condition on x though, so x is no longer a free parameter and we aren't interested in d(logp)/dx any more. So it's possible to store data in a model, but as far as I can tell, not differentiable data.
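
A minimal sketch of that point (assuming DynamicPPL's LogDensityFunction implements LogDensityProblems.dimension, as it is used further down this thread): after conditioning, only y remains a free parameter, so no gradient with respect to x is ever requested.

using DynamicPPL, Distributions, LogDensityProblems

@model function f()
    x ~ Normal()
    y ~ Normal(x)
end

cmodel = f() | (x = 1.0,)          # condition on x
ldf = LogDensityFunction(cmodel)
LogDensityProblems.dimension(ldf)  # 1: only y is left as a free parameter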

I've been trying for a while to come up with pathological ways of getting round these restrictions, but I couldn't really think of any way to do it. Happy to hear if you had some ideas on how to break Turing code though 😄

@penelopeysm (Member)

Hmmm, I realise I was only thinking of whether the model could be an active argument. I'll ponder the cache a bit as well.

@penelopeysm (Member) commented Mar 22, 2025

😬 Something like this, perhaps? The model evaluator function (model.f) closes over d:

julia> using DynamicPPL, Distributions, LogDensityProblems

julia> function g()
           d = [1.0]
           @model function g2()
               x ~ Normal()
               d[1] = x
               y ~ Normal(d[1])
           end
           return g2()
       end
g (generic function with 2 methods)

julia> model = g(); ldf = LogDensityFunction(model);

julia> LogDensityProblems.logdensity(ldf, randn(2))
-2.825520065854062

julia> ldf.model.f.d
1-element Vector{Float64}:
 1.0337255358589155

julia> LogDensityProblems.logdensity(ldf, randn(2))
-2.645539732162322

julia> ldf.model.f.d
1-element Vector{Float64}:
 1.1067816035197735

@gdalle (Contributor, Author) commented Mar 22, 2025

Yeah exactly, something like your last example should give the wrong result when used with Enzyme.

@penelopeysm (Member) commented Mar 23, 2025

I gave it a spin and it seems to be fine:

using DynamicPPL, Distributions, LogDensityProblems, ADTypes
using Mooncake, Enzyme
function g()
    d = [1.0]
    @model function g2()
        x ~ Normal()
        # see note below
        d[1] = x
        y ~ Normal(d[1])
    end
    return g2()
end
model = g()
params = [1.1777037635634742, -3.005045711130534]

for adtype in [
    AutoMooncake(; config=nothing),
    AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
    AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
]
    ldf = LogDensityFunction(model; adtype=adtype)
    logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
    @show logp, grad
end

# Note: ForwardDiff and ReverseDiff both don't like the line `d[1] = x` because
# `x` is a tracked type and it can't be reassigned into a Vector{Float}. However,
# the model
#     @model function h()
#         x ~ Normal()
#         y ~ Normal(x)
#     end
# is the same as `g2` and running ForwardDiff and ReverseDiff on that
# simplified model gives the same results for `(logp, grad)` as below.

# Mooncake
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])

# Enzyme forward
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])

# Enzyme reverse
# (logp, grad) = (-11.27906672779163, [-5.360453238257482, 4.182749474694008])

Do you know if that's related to the set_runtime_activity calls? Without runtime activity enabled, basically every Turing model fails (even if all the Constants are changed to Cache), so it always has to be enabled (from our perspective).

@gdalle (Contributor, Author) commented Mar 28, 2025

I must admit I don't understand why this works with Enzyme.

@gdalle (Contributor, Author) commented Mar 31, 2025

@penelopeysm I took some time to come up with a minimal working example without DI or Turing, and I think you might have hit a bug in Enzyme with your demo? As in, it happens to work but it really shouldn't: EnzymeAD/Enzyme.jl#2344

In general, annotating models as DI.Constant will only be valid if they do not close over storage that may contain active data. Otherwise, ForwardDiff will throw an error and Enzyme will silently return the wrong derivative.
Are users likely to encounter this pattern? Should it be forbidden in Turing?
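
To make the failure mode concrete, here is a minimal sketch without Turing (the names CacheHolder and logdensity_like are hypothetical): a callable closes over a Float64 buffer, writes an active value into it, and is nonetheless passed to DI as a Constant context.

using DifferentiationInterface, ADTypes
import ForwardDiff

struct CacheHolder
    d::Vector{Float64}
end

# `params` is the active argument; `holder` is declared Constant even though
# `holder.d` gets overwritten with an active value during the call.
function logdensity_like(params, holder::CacheHolder)
    holder.d[1] = params[1]
    return -0.5 * (params[1]^2 + (params[2] - holder.d[1])^2)
end

holder = CacheHolder([0.0])

# ForwardDiff refuses outright: it cannot store a Dual number in a
# Vector{Float64}, so this typically throws a MethodError on the assignment.
try
    gradient(logdensity_like, AutoForwardDiff(), [1.0, 2.0], Constant(holder))
catch err
    @show typeof(err)
end

# With AutoEnzyme, the same Constant annotation can instead produce a silently
# wrong derivative, since the write into holder.d is treated as inactive.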

@gdalle (Contributor, Author) commented Mar 31, 2025

It turns out this is actually an LLVM quirk 🤣
