113 changes: 107 additions & 6 deletions ext/LinearSolveMooncakeExt.jl
@@ -1,18 +1,19 @@
module LinearSolveMooncakeExt

using Mooncake
using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!
using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!, @is_primitive, primal, zero_fcodual, CoDual, rdata, fdata
using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem,
LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver,
defaultalg_adjoint_eval, solve
LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, LinearSolveAdjoint,
defaultalg_adjoint_eval, solve, LUFactorization
using LinearSolve.LinearAlgebra
using LazyArrays: @~, BroadcastArray
using SciMLBase

@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve),LinearProblem,Nothing} true ReverseMode
@from_chainrules MinimalCtx Tuple{
typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode
typeof(SciMLBase.solve),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{
Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode
Type{<:LinearProblem},AbstractMatrix,AbstractVector,SciMLBase.NullParameters} true ReverseMode

function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem)
f.data.A .+= t.A
@@ -29,4 +30,104 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
end
end

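# Accumulate a ChainRules-style tangent `t` for a `LinearCache` into Mooncake's fdata `f`.
# The cache carries no rdata component, so `NoRData()` is returned.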
function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
f.fields.A .+= t.A
f.fields.b .+= t.b
f.fields.u .+= t.u

return NoRData()
end

# rrules for LinearCache
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode

# rrules for solve!
# NOTE: Avoid Mooncake.prepare_gradient_cache here; use Mooncake.prepare_pullback_cache (and hence
# Mooncake.value_and_pullback!!) instead. Calling Mooncake.prepare_gradient_cache on a function that
# contains solve! triggers the unsupported-adjoint-case exception in the rrules below, because
# prepare_gradient_cache resets stacks and internal state by running the reverse pass once with a
# zero gradient. If a valid cache is already available, Mooncake.value_and_gradient!! can be used
# directly.
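#
# A minimal usage sketch (illustrative only; it mirrors the tests in test/nopre/mooncake.jl and
# assumes a user function `g(A, b)` that calls init/solve! internally):
#
#     rule = Mooncake.build_rrule(g, A, b)
#     val, grads = Mooncake.value_and_pullback!!(rule, 1.0, g, A, b)
#
# whereas the one-shot pattern
#
#     cache = Mooncake.prepare_gradient_cache(g, A, b)
#     Mooncake.value_and_gradient!!(cache, g, A, b)
#
# is the one that hits the exception described above.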

@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm,Vararg}
@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing,Vararg}

function Mooncake.rrule!!(sig::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache},
    _alg::CoDual{Nothing}, args::Vararg{Any,N}; kwargs...) where {N}
cache = primal(_cache)
assump = LinearSolve.OperatorAssumptions()
# `CoDual` is immutable, so wrap the resolved default algorithm in a fresh `CoDual` rather than
# mutating `_alg`.
alg = zero_fcodual(LinearSolve.defaultalg(cache.A, cache.b, assump))
return Mooncake.rrule!!(sig, _cache, alg, args...; kwargs...)
end

function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache},
    _alg::CoDual{<:SciMLLinearSolveAlgorithm}, args::Vararg{Any,N};
    alias_A = zero_fcodual(LinearSolve.default_alias_A(_alg.x, _cache.x.A, _cache.x.b)),
    kwargs...) where {N}

cache = primal(_cache)
alg = primal(_alg)
_args = map(primal, args)

(; A, b, sensealg) = cache
A_orig = copy(A)
b_orig = copy(b)

@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# Caching of `A` for the reverse pass follows the same logic as the ChainRules rrule for SciMLBase.solve.
if sensealg.linsolve === missing
if !(alg isa LinearSolve.AbstractFactorization || alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
alg isa LinearSolve.DefaultLinearSolver)
A_ = alias_A ? deepcopy(A) : A
end
else
A_ = deepcopy(A)
end

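# The primal solve! may factorize or overwrite `cache.A` and `cache.b` in place, so the originals
# saved above are restored afterwards; the pullback below rebuilds the problem from them.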
sol = zero_fcodual(solve!(cache))
cache.A = A_orig
cache.b = b_orig

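# Adjoint of the linear solve (a sketch of the standard derivation): for A u = b and a loss L(u)
# with cotangent ∂u = ∂L/∂u, the adjoint variable λ solves Aᵀ λ = ∂u, and then
#     ∂L/∂A = -λ uᵀ,   ∂L/∂b = λ.
# The branches below compute λ, reusing a cached factorization of A whenever one is available.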
function solve!_adjoint(::NoRData)
∂∅ = NoRData()
cachenew = init(LinearProblem(cache.A, cache.b), LUFactorization(), _args...; kwargs...)
new_sol = solve!(cachenew)
∂u = sol.dx.data.u

if sensealg.linsolve === missing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(adjoint(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

tu = adjoint(new_sol.u)
∂A = BroadcastArray(@~ .-(λ .* tu))
∂b = λ

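# If the returned solution carries no cotangent (so λ, and with it ∂A and ∂b, is zero) while the
# solution itself is nonzero, the caller most likely read the result from `cache.u` instead of the
# returned solution object; that pattern cannot be differentiated by this rrule.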
if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
end

fdata(_cache.dx).fields.A .+= ∂A
fdata(_cache.dx).fields.b .+= ∂b
fdata(_cache.dx).fields.u .+= ∂u

# rdata for the cache is a struct whose fields are all NoRData
return (∂∅, rdata(_cache.dx), ∂∅, ntuple(_ -> ∂∅, length(args))...)
end

return sol, solve!_adjoint
end

end
16 changes: 16 additions & 0 deletions src/adjoint.jl
@@ -99,3 +99,19 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end

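# ChainRules rrules for `init`: when no algorithm is given, resolve the default algorithm first and
# defer to the generic rule below; the pullback maps the cache tangent back onto a `LinearProblem`
# tangent for `A` and `b`.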
function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
assump = OperatorAssumptions(issquare(prob.A))
alg = defaultalg(prob.A, prob.b, assump)
CRC.rrule(T, prob, alg, args...; kwargs...)
end

function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...)
init_res = LinearSolve.init(prob, alg, args...; kwargs...)
function init_adjoint(∂init)
∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
return NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...
end

return init_res, init_adjoint
end
145 changes: 143 additions & 2 deletions test/nopre/mooncake.jl
@@ -11,9 +11,7 @@ b1 = rand(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end
@@ -153,3 +151,146 @@ for alg in (
@test results[1] ≈ fA(A)
@test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

function f_(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end
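# The ForwardDiff gradients below serve as the reference: the Mooncake pullback through `init` plus
# two `solve!` calls (with `cache.b` replaced in between) should agree with them.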

f_primal = f_(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_2(A, b1, b2; alg=RFLUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f_2(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_2, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_2, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_2(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_3(A, b1, b2; alg=KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f_3(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_3, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_3, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_3(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_4(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
solve!(cache)
s1 = copy(cache.u)
cache.b = b2
solve!(cache)
s2 = copy(cache.u)
norm(s1 + s2)
end
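# f_4 reads the solution from `cache.u` instead of from the returned solution object, which is the
# pattern the solve! rrule rejects, so the pullback is expected to throw.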

A = rand(n, n);
b1 = rand(n);
b2 = rand(n);
f_primal = f_4(copy(A), copy(b1), copy(b2))

rule = Mooncake.build_rrule(f_4, copy(A), copy(b1), copy(b2))
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
rule, 1.0,
f_4, copy(A), copy(b1), copy(b2)
)

# dA2 = ForwardDiff.gradient(x -> f_4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

function fnice(A, b, alg)
prob = LinearProblem(A, b)
sol1 = solve(prob, alg)
return sum(sol1.u)
end

@testset for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES()
)
# for B
fb_closure = b -> fnice(A, b, alg)
fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

val, en_jac = Mooncake.value_and_gradient!!(
prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
fnice, copy(A), copy(b1), alg
)
@test en_jac[3] ≈ fd_jac_b rtol = 1e-5

# For A
fA_closure = A -> fnice(A, b1, alg)
fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
A_grad = en_jac[2] |> vec
@test A_grad ≈ fd_jac_A rtol = 1e-5
end