Skip to content

ForwardDiff Overload Fixes #629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,57 @@ using ForwardDiff: Dual, Partials
using SciMLBase
using RecursiveArrayTools

const DualLinearProblem = LinearProblem{

# Type alias for non-nested dual numbers: a `Dual` whose value type `V` is a plain
# `Float64` rather than another `Dual`.
#
# The bound is declared on the alias parameters themselves. Writing it as a trailing
# `where {T, V <: Float64, P}` on the right-hand side would introduce fresh variables
# that shadow the alias parameters, so `SingleDual{T, V, P}` would ignore its arguments.
# The set of types matched by `<:SingleDual` is the same either way.
# NOTE(review): `V <: Float64` only admits `Float64` (it is concrete); confirm that other
# float types such as `Float32` are intentionally excluded.
const SingleDual{T, V <: Float64, P} = Dual{T, V, P}

# Type alias for nested dual numbers: a `Dual` whose value type `V` is itself a `Dual`.
#
# The bound is declared on the alias parameters themselves. Writing it as a trailing
# `where {T, V <: Dual, P}` on the right-hand side would introduce fresh variables that
# shadow the alias parameters, so `NestedDual{T, V, P}` would ignore its arguments.
# The set of types matched by `<:NestedDual` is the same either way.
const NestedDual{T, V <: Dual, P} = Dual{T, V, P}

# A `LinearProblem` in which both Dual-carrying type parameters are non-nested duals
# (`SingleDual`), either as scalars or as arrays of them.
# NOTE(review): judging by the sibling aliases `DualALinearProblem` / `DualBLinearProblem`,
# the third and fourth parameters are presumably the types of `A` and `b` — confirm against
# the `LinearProblem` definition in SciMLBase.
const SingleDualLinearProblem = LinearProblem{
# First parameter: may be a number, an array, or `Nothing`.
<:Union{Number, <:AbstractArray, Nothing}, iip,
# Third and fourth parameters: a `SingleDual` scalar or an array of `SingleDual`.
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
# Last parameter is left unconstrained.
<:Any
} where {iip}

const NestedDualLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
<:Any
} where {iip, T, V, P}
} where {iip}

const DualALinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
<:Union{Number, <:AbstractArray},
<:Any
} where {iip, T, V, P}
} where {iip}

const DualBLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{Number, <:AbstractArray},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
<:Any
} where {iip, T, V, P}
} where {iip}

const DualAbstractLinearProblem = Union{
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}

# Cache for linear problems with ForwardDiff dual numbers. It bundles the wrapped
# `linear_cache` with the partials of `A`, `b`, `u` and the dual-valued `A`, `b`, `u`.
# Field order matters: `SciMLBase.init` constructs this type positionally.
LinearSolve.@concrete mutable struct DualLinearCache
# Cache passed to `solve!` in `linearsolve_forwarddiff_solve`; its `b` is temporarily
# replaced by each partial right-hand side and then reset to the primal `b`.
linear_cache
# Passed to `linearsolve_dual_solution` when the dual-valued solution is rebuilt.
dual_type

# Partials of `A`, `b` and `u`. `init` fills the first two with `partial_vals(A)` and
# `partial_vals(b)`; `partials_u` starts as `zero.(∂_b)` (or `∂_b` itself when that
# is `nothing`).
partials_A
partials_b
partials_u

# Dual-valued `A`, `b` and `u`, returned by `getproperty` for `:A`, `:b` and `:u`.
# `dual_u` starts as `zero.(b)` in `init` and is assigned the dual solution in `solve!`.
dual_A
dual_b
dual_u
end

function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
Expand All @@ -55,16 +75,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa

rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)

partial_cache = cache.linear_cache
partial_cache.u = dual_u0

cache.linear_cache.u = dual_u0
# We can reuse the linear cache, because the same factorization will work for the partials.
for i in eachindex(rhs_list)
partial_cache.b = rhs_list[i]
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
cache.linear_cache.b = rhs_list[i]
rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
end

# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
partial_cache.b = primal_b
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
cache.linear_cache.b = primal_b

partial_sols = rhs_list

Expand Down Expand Up @@ -144,7 +163,6 @@ function SciMLBase.init(
∂_A = partial_vals(A)
∂_b = partial_vals(b)

#primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)

if get_dual_type(prob.A) !== nothing
Expand All @@ -157,7 +175,7 @@ function SciMLBase.init(
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
sensealg = sensealg, u0 = new_u0, kwargs...)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b))
end

function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
Expand All @@ -166,34 +184,45 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
cache::DualLinearCache, cache.alg, args...; kwargs...)

dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)

cache.dual_u = dual_sol

return SciMLBase.build_linear_solution(
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
)
end

# If setting A, b or u on a DualLinearCache, put the Dual-stripped versions in the LinearCache.
# Also "forwards" setproperty! to the wrapped LinearCache for properties that only it defines.
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A || sym === :b || sym === :u
setproperty!(dc.linear_cache, sym, nodual_value(val))
elseif hasfield(DualLinearCache, sym)
setfield!(dc,sym,val)
elseif hasfield(LinearSolve.LinearCache, sym)
setproperty!(dc.linear_cache, sym, val)
end


# Update the partials if setting A or b
if sym === :A
setfield!(dc, :partials_A, partial_vals(val))
elseif sym === :b
elseif sym === :b
setfield!(dc, :partials_b, partial_vals(val))
else
setfield!(dc, sym, val)
elseif sym === :u
setfield!(dc, :partials_u, partial_vals(val))
end
end

# "Forwards" getproperty to LinearCache if necessary
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
if hasfield(LinearSolve.LinearCache, sym)
if sym === :A
dc.dual_A
elseif sym === :b
dc.dual_b
elseif sym === :u
dc.dual_u
elseif hasfield(LinearSolve.LinearCache, sym)
return getproperty(dc.linear_cache, sym)
else
return getfield(dc, sym)
Expand Down Expand Up @@ -239,3 +268,4 @@ end


end

2 changes: 1 addition & 1 deletion test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ cache.b = new_b
x_p = solve!(cache)
backslash_x_p = A \ new_b

@test ≈(x_p, backslash_x_p, rtol = 1e-9)
@test ≈(x_p, backslash_x_p, rtol = 1e-9)
Loading