Add basic reverse mode #129

Merged (17 commits) on Jun 20, 2024.
6 changes: 6 additions & 0 deletions Project.toml
@@ -14,6 +14,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
StochasticADEnzymeExt = "Enzyme"

[compat]
ChainRulesCore = "1.15"
ChainRulesOverloadGeneration = "0.1"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -2,4 +2,5 @@
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"
12 changes: 10 additions & 2 deletions docs/make.jl
@@ -1,6 +1,9 @@
using Pkg

using Documenter, StochasticAD, DocThemeIndigo
using Documenter
using StochasticAD
using DocThemeIndigo
using Literate

### Formatting

@@ -18,13 +21,18 @@ pages = [
"tutorials/random_walk.md",
"tutorials/game_of_life.md",
"tutorials/particle_filter.md",
"tutorials/optimizations.md"
"tutorials/optimizations.md",
"tutorials/reverse_demo.md"
],
"Public API" => "public_api.md",
"Developer documentation" => "devdocs.md",
"Limitations" => "limitations.md"
]

### Prepare literate tutorials

# TODO (for now they are manually built into docs/src/tutorials and checked into repo)

### Make docs

makedocs(sitename = "StochasticAD.jl",
9 changes: 7 additions & 2 deletions docs/src/public_api.md
@@ -18,6 +18,13 @@ of standard AD, where derivatives of discrete random steps are dropped:
StochasticAD.dual_number
```

## Algorithms

```@docs
StochasticAD.ForwardAlgorithm
StochasticAD.EnzymeReverseAlgorithm
```
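
A minimal sketch of selecting an algorithm (the toy function `X` below is illustrative, not part of the API):

```julia
using Distributions, StochasticAD

X(p) = rand(Bernoulli(p)) # toy program with a discrete random step

# Forward mode with the default pruning backend:
derivative_estimate(X, 0.5, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()))

# Reverse mode (requires `using Enzyme`), with an input-independent pruning
# strategy as the reverse-mode algorithm requires:
# derivative_estimate(X, 0.5, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
```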

## Smoothing

What happens if we were to run [`derivative_contribution`](@ref) after each step, instead of only at the end? This is *smoothing*, which combines the second and third components of a single stochastic triple into a single dual component.
@@ -27,8 +34,6 @@ Forward smoothing rules are provided through `ForwardDiff`, and backward rules t
Currently, special discrete->discrete constructs such as array indexing are not supported for smoothing.




## Optimization

We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:
78 changes: 78 additions & 0 deletions docs/src/tutorials/reverse_demo.md
@@ -0,0 +1,78 @@
```@meta
EditURL = "../../../tutorials/reverse_example/reverse_demo.jl"
```

# Simple reverse mode example

```@setup reverse_demo
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()

import Random
Random.seed!(1234)
```

Load our packages

````@example reverse_demo
using StochasticAD
using Distributions
using Enzyme
using LinearAlgebra
````

Let us define our target function.

````@example reverse_demo
# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string.
string_value(strings, index) = Int(sum(codepoint, strings[index]))
string_value(strings, index::StochasticTriple) = StochasticAD.propagate(index -> string_value(strings, index), index)

function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())
    strings = ["cat", "dog", "meow", "woofs"]
    index = randst(Categorical(θ); derivative_coupling)
    return string_value(strings, index)
end

θ = [0.1, 0.5, 0.3, 0.1]
@show f(θ)
nothing
````

First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD.

````@example reverse_demo
u = [1.0, 2.0, 4.0, -7.0]
@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
nothing
````

Now, let's do the same with reverse-mode.

````@example reverse_demo
@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
````

Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.

````@example reverse_demo
forward() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
reverse() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))

N = 40000
directional_derivs_fwd = [forward() for i in 1:N]
derivs_bwd = [reverse() for i in 1:N]
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]
println("Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))")
println("Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))")
@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)

nothing
````
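
As an extra sanity check (a sketch, not part of the original demo): on the probability simplex, `E[f(θ)]` is the `θ`-weighted average of the string values, so for a direction `u` with `sum(u) == 0`, the exact directional derivative is `dot(u, v)`, where `v` collects the integer values of the four strings. Both estimates above should land near this value.

````@example reverse_demo
v = [Int(sum(codepoint, s)) for s in ["cat", "dog", "meow", "woofs"]]
@show dot(u, v) # exact directional derivative along u
nothing
````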

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

38 changes: 38 additions & 0 deletions ext/StochasticADEnzymeExt.jl
@@ -0,0 +1,38 @@
module StochasticADEnzymeExt

using StochasticAD
using Enzyme

function enzyme_target(u, X, p, backend)
    # equivalent to derivative_estimate(X, p; backend, direction = u), but specialized
    # to real-valued output to make Enzyme happier
    st = StochasticAD.stochastic_triple_direction(X, p, u; backend)
    if !(StochasticAD.valtype(st) <: Real)
        error("EnzymeReverseAlgorithm only supports real-valued outputs.")
    end
    return derivative_contribution(st)
end

function StochasticAD.derivative_estimate(X, p, alg::StochasticAD.EnzymeReverseAlgorithm;
        direction = nothing, alg_data = (; forward_u = nothing))
    if !isnothing(direction)
        error("EnzymeReverseAlgorithm does not support keyword argument `direction`")
    end
    if p isa AbstractVector
        Δu = zeros(float(eltype(p)), length(p))
        u = isnothing(alg_data.forward_u) ?
            rand(StochasticAD.RNG, float(eltype(p)), length(p)) : alg_data.forward_u
        autodiff(Enzyme.Reverse, enzyme_target, Active, Duplicated(u, Δu),
            Const(X), Const(p), Const(alg.backend))
        return Δu
    elseif p isa Real
        u = isnothing(alg_data.forward_u) ? rand(StochasticAD.RNG, float(typeof(p))) :
            alg_data.forward_u
        ((du, _, _, _),) = autodiff(Enzyme.Reverse, enzyme_target, Active, Active(u),
            Const(X), Const(p), Const(alg.backend))
        return du
    else
        error("EnzymeReverseAlgorithm only supports p::Real or p::AbstractVector")
    end
end

end
1 change: 1 addition & 0 deletions src/StochasticAD.jl
@@ -52,6 +52,7 @@ include("stochastic_triple.jl") # Defines stochastic triple object and higher le
include("general_rules.jl") # Defines rules for propagation through deterministic functions
include("discrete_randomness.jl") # Defines rules for propagation through discrete random functions
include("propagate.jl") # Experimental generalized forward propagation functionality
include("algorithms.jl") # Add algorithm-based higher-level interface
include("misc.jl") # Miscellaneous functions that do not fit in the usual flow

end
96 changes: 96 additions & 0 deletions src/algorithms.jl
@@ -0,0 +1,96 @@
abstract type AbstractStochasticADAlgorithm end

"""
    ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on forward propagation of stochastic triples.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.

!!! note
    The required computation time for forward-mode AD scales linearly with the number of
    parameters in `p` (but is unaffected by the number of parameters in `X(p)`).
"""
struct ForwardAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B
end
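
# Usage sketch (illustrative): a forward-mode estimate, optionally in a direction `u`:
#
#     derivative_estimate(X, p, ForwardAlgorithm(PrunedFIsBackend()); direction = u)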

"""
    EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on transposing the propagation of stochastic triples to
produce a reverse-mode algorithm. The transposition is performed by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl),
which must be loaded for the algorithm to run.

Currently, only real- and vector-valued inputs `p` and real-valued outputs `X(p)` are supported.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.

In the call to `derivative_estimate`, this algorithm optionally accepts `alg_data` with the field `forward_u`,
which specifies the directional derivative used in the forward pass that will be transposed.
If `forward_u` is not provided, it is randomly generated.

!!! warning
    For the reverse-mode algorithm to yield correct results, the employed `backend` cannot use input-dependent pruning
    strategies. A suggested reverse-mode compatible backend is `PrunedFIsBackend(Val(:wins))`.

    Additionally, this algorithm relies on the ability of `Enzyme.jl` to differentiate the forward stochastic triple run.
    It is recommended to check that the primal function `X` is type stable for its input `p` using a tool such as
    [JET.jl](https://github.com/aviatesk/JET.jl), with all code executed in a function with no global state.
    In addition, sometimes `X` may be type stable but stochastic triples introduce additional type instabilities.
    This can be debugged by checking the type stability of Enzyme's target, which is
    `Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend)`,
    where `u` is a test direction.

!!! note
    For more details on the reverse-mode approach, see the following papers and talks:

    * ["You Only Linearize Once: Tangents Transpose to Gradients"](https://arxiv.org/abs/2204.10923), Radul et al. 2022.
    * ["Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators"](https://www.youtube.com/watch?v=pnPmk-leSsE), Becker et al. 2024.
    * ["Probabilistic Programming with Programmable Variational Inference"](https://pldi24.sigplan.org/details/pldi-2024-papers/87/Probabilistic-Programming-with-Programmable-Variational-Inference), Becker et al. 2024.
"""
struct EnzymeReverseAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B
end
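
# Usage sketch (illustrative, mirroring the docstring above): a reverse-mode gradient
# estimate, optionally fixing the forward direction that gets transposed:
#
#     derivative_estimate(X, p, EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins)));
#         alg_data = (; forward_u = ones(length(p))))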

function derivative_estimate(
        X, p, alg::ForwardAlgorithm; direction = nothing, alg_data::NamedTuple = (;))
    return derivative_estimate(X, p; backend = alg.backend, direction)
end

@doc raw"""
    derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))

Compute an unbiased estimate of ``\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}``,
the derivative of the expectation of the random function `X(p)` with respect to its input `p`.

Both `p` and `X(p)` can be any object supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays.
The output of `derivative_estimate` has the same outer structure as `p`, but with each
scalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry.
For example, if `X(p) <: AbstractMatrix` and `p <: Real`, then the output would be a matrix.

The `alg` keyword argument specifies the [algorithm](public_api.md#Algorithms) used to compute the derivative estimate.
For backward compatibility, an additional signature `derivative_estimate(X, p; backend, direction=nothing)`
is supported, which uses `ForwardAlgorithm` by default with the supplied `backend`.
The `alg_data` keyword argument can specify any additional data that specific algorithms accept or require.

When `direction` is provided, the output is only differentiated with respect to a perturbation
of `p` in that direction.

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0

julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
 0.2857142857142857
 0.6666666666666666
 0.0
```
"""
derivative_estimate