diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index f2137eccb..6923c1bb2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -7,37 +7,57 @@ using ForwardDiff: Dual, Partials using SciMLBase using RecursiveArrayTools -const DualLinearProblem = LinearProblem{ + +# Define type for non-nested dual numbers +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Float64 , P} + +# Define type for nested dual numbers +const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} + +const SingleDualLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, iip, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Any +} where {iip} + +const NestedDualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualALinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Union{Number, <:AbstractArray}, <:Any -} where {iip, T, V, P} +} where {iip} const DualBLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, <:Union{Number, <:AbstractArray}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualAbstractLinearProblem = Union{ - DualLinearProblem, DualALinearProblem, DualBLinearProblem} + SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache dual_type + partials_A partials_b + partials_u + + dual_A + dual_b + dual_u end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) @@ -55,16 +75,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_cache = cache.linear_cache - partial_cache.u = dual_u0 - + cache.linear_cache.u = dual_u0 + # We can reuse the linear cache, because the same factorization will work for the partials. for i in eachindex(rhs_list) - partial_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) + cache.linear_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) end - # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to - partial_cache.b = primal_b + # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to + cache.linear_cache.b = primal_b partial_sols = rhs_list @@ -144,7 +163,6 @@ function SciMLBase.init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) if get_dual_type(prob.A) !== nothing @@ -157,7 +175,7 @@ function SciMLBase.init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b)) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -166,34 +184,45 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + + cache.dual_u = dual_sol + return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) end # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -# Also "forwards" setproperty so that function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc,sym,val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end + # Update the partials if setting A or b if sym === :A setfield!(dc, :partials_A, partial_vals(val)) - elseif sym === :b + elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) - else - setfield!(dc, sym, val) + elseif sym === :u + setfield!(dc, :partials_u, partial_vals(val)) end end # "Forwards" getproperty to LinearCache if necessary function Base.getproperty(dc::DualLinearCache, sym::Symbol) - if hasfield(LinearSolve.LinearCache, sym) + if sym === :A + dc.dual_A + elseif sym === :b + dc.dual_b + elseif sym === :u + dc.dual_u + elseif hasfield(LinearSolve.LinearCache, sym) return getproperty(dc.linear_cache, sym) else return getfield(dc, sym) @@ -239,3 +268,4 @@ end end + diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index eb66c64dc..d73ec1746 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -79,4 +79,4 @@ cache.b = new_b x_p = solve!(cache) backslash_x_p = A \ new_b -@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file +@test ≈(x_p, backslash_x_p, rtol = 1e-9)