From 00d9f4c873afc3b731031a1efa676b3e0dad7ad4 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 09:58:19 -0400 Subject: [PATCH 1/5] don't set properties again --- ext/LinearSolveForwardDiffExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index f2137eccb..40bbff9f4 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -186,8 +186,6 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) setfield!(dc, :partials_A, partial_vals(val)) elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) - else - setfield!(dc, sym, val) end end From f672e55d0994252c05c972c8ab3201bd47dea2d4 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:34:31 -0400 Subject: [PATCH 2/5] make sure that when A, b, or u are accessed you get the Dual numbers --- ext/LinearSolveForwardDiffExt.jl | 41 ++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 40bbff9f4..3f56528e6 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -36,8 +36,14 @@ const DualAbstractLinearProblem = Union{ LinearSolve.@concrete mutable struct DualLinearCache linear_cache dual_type + partials_A partials_b + partials_u + + dual_A + dual_b + dual_u end function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) @@ -55,16 +61,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) - partial_cache = cache.linear_cache - partial_cache.u = dual_u0 - + cache.linear_cache.u = dual_u0 + # We can reuse the linear cache, because the same factorization will work for the partials. for i in eachindex(rhs_list) - partial_cache.b = rhs_list[i] - rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) + cache.linear_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u) end - # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to - partial_cache.b = primal_b + # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to + cache.linear_cache.b = primal_b + Main.@infiltrate partial_sols = rhs_list @@ -144,7 +150,6 @@ function SciMLBase.init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) if get_dual_type(prob.A) !== nothing @@ -157,7 +162,7 @@ function SciMLBase.init( primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b)) end function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) @@ -166,13 +171,16 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + Main.@infiltrate + + cache.dual_u = dual_sol + Main.@infiltrate return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) end # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -# Also "forwards" setproperty so that function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u @@ -184,14 +192,23 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # Update the partials if setting A or b if sym === :A setfield!(dc, :partials_A, partial_vals(val)) - elseif sym === :b + elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) + elseif sym === :u + Main.@infiltrate + setfield!(dc, :partials_u, partial_vals(val)) end end # "Forwards" getproperty to LinearCache if necessary function Base.getproperty(dc::DualLinearCache, sym::Symbol) - if hasfield(LinearSolve.LinearCache, sym) + if sym === :A + dc.dual_A + elseif sym === :b + dc.dual_b + elseif sym === :u + dc.dual_u + elseif hasfield(LinearSolve.LinearCache, sym) return getproperty(dc.linear_cache, sym) else return getfield(dc, sym) From 9f81fc9d83b370927f3e83e53aa9ffcad7789438 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:49:17 -0400 Subject: [PATCH 3/5] make sure cache u is updated --- ext/LinearSolveForwardDiffExt.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3f56528e6..9197d01b2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -185,9 +185,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(DualLinearCache, sym) + setfield!(dc,sym,val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end + # Update the partials if setting A or b if sym === :A @@ -195,7 +198,6 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) elseif sym === :b setfield!(dc, :partials_b, partial_vals(val)) elseif sym === :u - Main.@infiltrate setfield!(dc, :partials_u, partial_vals(val)) end end @@ -254,3 +256,4 @@ end end + From ed77d30518ed3c41b9c649e8f3e1435e7b0f9a65 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 2 Jul 2025 14:56:09 -0400 Subject: [PATCH 4/5] no infiltrate --- ext/LinearSolveForwardDiffExt.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 9197d01b2..159782f30 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -70,7 +70,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to cache.linear_cache.b = primal_b - Main.@infiltrate partial_sols = rhs_list @@ -171,10 +170,9 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) cache::DualLinearCache, cache.alg, args...; kwargs...) dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) - Main.@infiltrate cache.dual_u = dual_sol - Main.@infiltrate + return SciMLBase.build_linear_solution( cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats ) From 295a3a3240e0c1ed2aed3e8d0c98b3370cab10ea Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 3 Jul 2025 12:26:22 -0400 Subject: [PATCH 5/5] exclude nested duals --- ext/LinearSolveForwardDiffExt.jl | 32 +++++++++++++++++++++++--------- test/forwarddiff_overloads.jl | 2 +- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 159782f30..6923c1bb2 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -7,31 +7,45 @@ using ForwardDiff: Dual, Partials using SciMLBase using RecursiveArrayTools -const DualLinearProblem = LinearProblem{ + +# Define type for non-nested dual numbers +const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Float64 , P} + +# Define type for nested dual numbers +const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P} + +const SingleDualLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, iip, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, + <:Any +} where {iip} + +const NestedDualLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, + <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualALinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Union{Number, <:AbstractArray}, <:Any -} where {iip, T, V, P} +} where {iip} const DualBLinearProblem = LinearProblem{ <:Union{Number, <:AbstractArray, Nothing}, iip, <:Union{Number, <:AbstractArray}, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}}, <:Any -} where {iip, T, V, P} +} where {iip} const DualAbstractLinearProblem = Union{ - DualLinearProblem, DualALinearProblem, DualBLinearProblem} + SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} LinearSolve.@concrete mutable struct DualLinearCache linear_cache diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index eb66c64dc..d73ec1746 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -79,4 +79,4 @@ cache.b = new_b x_p = solve!(cache) backslash_x_p = A \ new_b -@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file +@test ≈(x_p, backslash_x_p, rtol = 1e-9)