Skip to content

Commit 6473340

Browse files
Merge pull request #505 from ChrisRackauckas-Claude/default-initialization-changes
Change default initialization algorithm to intelligent selection
2 parents 1534921 + b700c76 commit 6473340

File tree

11 files changed

+195
-117
lines changed

11 files changed

+195
-117
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Sundials"
22
uuid = "c3572dad-4567-51f8-b174-8c6c989267f4"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "4.28.0"
4+
version = "5.0.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -13,6 +13,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1515
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
16+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1617
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1819
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -29,7 +30,7 @@ ArrayInterface = "7.17.1"
2930
CEnum = "0.5"
3031
DAEProblemLibrary = "0.1"
3132
DataStructures = "0.18, 0.19"
32-
DiffEqBase = "6.154"
33+
DiffEqBase = "6.190.2"
3334
DiffEqCallbacks = "4"
3435
DifferentiationInterface = "0.6, 0.7"
3536
ExplicitImports = "1"
@@ -39,6 +40,7 @@ Libdl = "1"
3940
LinearAlgebra = "1"
4041
LinearSolve = "3.40.0"
4142
Logging = "1"
43+
NonlinearSolveBase = "1.16"
4244
ModelingToolkit = "10"
4345
ODEProblemLibrary = "1"
4446
PrecompileTools = "1"

src/Sundials.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using DiffEqBase: DiffEqBase, NonlinearFunction, ODEFunction, add_saveat!,
1010
get_tstops_array, initialize!, isinplace,
1111
reeval_internals_due_to_modification!, reinit!, savevalues!,
1212
set_proposed_dt!, solve, solve!, step!, terminate!, u_modified!,
13-
update_coefficients!, warn_compat
13+
update_coefficients!, warn_compat, DefaultInit, BrownFullBasicInit,
14+
ShampineCollocationInit
1415
using SciMLBase: AbstractSciMLOperator, DAEProblem, ODEProblem, ReturnCode,
1516
SciMLBase, SplitODEProblem, VectorContinuousCallback
1617
import Accessors: @reset
@@ -22,7 +23,11 @@ using Logging: Logging
2223
using SparseArrays: SparseArrays
2324
using LinearAlgebra: LinearAlgebra
2425

26+
27+
28+
import NonlinearSolveBase # Required for KINSOL definition to NonlinearSolve
2529
import LinearSolve # Required for initialization
30+
2631
using Libdl: Libdl
2732
using CEnum: CEnum, @cenum
2833

@@ -54,7 +59,7 @@ using Sundials_jll: Sundials_jll, libsundials_core,
5459

5560
export solve,
5661
SundialsODEAlgorithm, SundialsDAEAlgorithm, ARKODE, CVODE_BDF, CVODE_Adams, IDA,
57-
KINSOL
62+
KINSOL, DefaultInit, BrownFullBasicInit, ShampineCollocationInit
5863

5964
# some definitions from the system C headers wrapped into the types_and_consts.jl
6065
const DBL_MAX = prevfloat(Inf)
@@ -96,7 +101,7 @@ include("common_interface/verbosity.jl")
96101
include("common_interface/algorithms.jl")
97102
include("common_interface/integrator_types.jl")
98103
include("common_interface/integrator_utils.jl")
99-
include("common_interface/initialize_dae.jl")
104+
include("common_interface/initialize.jl")
100105
include("common_interface/solve.jl")
101106

102107
import PrecompileTools

src/common_interface/initialize.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# DefaultInit for all Sundials integrators - handles ModelingToolkit parameter initialization
2+
function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator,
3+
initializealg::DefaultInit)
4+
prob = integrator.sol.prob
5+
if prob.f.initialization_data !== nothing
6+
DiffEqBase.initialize_dae!(integrator, SciMLBase.OverrideInit())
7+
else
8+
DiffEqBase.initialize_dae!(integrator, SciMLBase.CheckInit())
9+
end
10+
end
11+
12+
function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
13+
initializealg::BrownFullBasicInit)
14+
if integrator.u_modified
15+
IDAReinit!(integrator)
16+
end
17+
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
18+
tstart, tend = integrator.sol.prob.tspan
19+
# Use abstol from algorithm parameter (defaults to 1e-10)
20+
if any(abs.(integrator.tmp) .>= initializealg.abstol)
21+
if integrator.sol.prob.differential_vars === nothing
22+
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
23+
end
24+
# BrownFullBasicInit only modifies algebraic variables
25+
init_type = IDA_YA_YDP_INIT
26+
# Use preallocated NVector for differential_vars
27+
if integrator.diff_vars_nvec !== nothing
28+
integrator.flag = IDASetId(integrator.mem, integrator.diff_vars_nvec)
29+
else
30+
error("differential_vars NVector not preallocated but needed for IDASetId")
31+
end
32+
dt = integrator.dt == tstart ? tend : integrator.dt
33+
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)
34+
35+
# Reflect consistent initial conditions back into the integrator's
36+
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
37+
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
38+
end
39+
if integrator.t == tstart && integrator.flag < 0
40+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
41+
ReturnCode.InitialFailure)
42+
end
43+
end
44+
45+
function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
46+
initializealg::ShampineCollocationInit)
47+
if integrator.u_modified
48+
IDAReinit!(integrator)
49+
end
50+
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
51+
tstart, tend = integrator.sol.prob.tspan
52+
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
53+
# ShampineCollocationInit modifies all variables
54+
init_type = IDA_Y_INIT
55+
# Use initdt from algorithm if provided, otherwise fall back to integrator.dt
56+
dt = if initializealg.initdt !== nothing
57+
initializealg.initdt
58+
elseif integrator.dt == tstart
59+
tend
60+
else
61+
integrator.dt
62+
end
63+
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)
64+
65+
# Reflect consistent initial conditions back into the integrator's
66+
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
67+
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
68+
end
69+
if integrator.t == tstart && integrator.flag < 0
70+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
71+
ReturnCode.InitialFailure)
72+
end
73+
end
74+
75+
function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator,
76+
initializealg::SciMLBase.CheckInit)
77+
# Not allowed to be a DAE, so no-op
78+
end
79+
80+
function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initalg::SciMLBase.NoInit) end
81+
82+
function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initalg::SciMLBase.OverrideInit)
83+
prob = integrator.sol.prob
84+
nlsolve_alg = KINSOL()
85+
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, Val(isinplace(prob)); nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
86+
87+
if isinplace(prob)
88+
integrator.u .= u0
89+
if length(integrator.sol.u) == 1
90+
integrator.sol.u[1] .= u0
91+
end
92+
else
93+
integrator.u = u0
94+
if length(integrator.sol.u) == 1
95+
integrator.sol.u[1] = u0
96+
end
97+
end
98+
integrator.p = p
99+
sol = integrator.sol
100+
@reset sol.prob.p = integrator.p
101+
integrator.sol = sol
102+
103+
if success
104+
integrator.u_modified = true
105+
else
106+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure)
107+
end
108+
end
109+
110+
# Implementation of CheckInit for IDAIntegrator
111+
function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
112+
initializealg::SciMLBase.CheckInit)
113+
if integrator.u_modified
114+
IDAReinit!(integrator)
115+
end
116+
117+
# Evaluate the DAE residual at the initial conditions
118+
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
119+
120+
# Check if residuals are within tolerance
121+
if any(abs.(integrator.tmp) .>= integrator.opts.abstol)
122+
error("""
123+
DAE initialization failed with CheckInit: Initial conditions do not satisfy the DAE constraints.
124+
125+
The residual norm is $(maximum(abs.(integrator.tmp))), which exceeds the tolerance $(integrator.opts.abstol).
126+
127+
Note that the initial conditions include both `du0` (derivatives) and `u0` (states),
128+
and the choice of derivatives must be compatible with the states.
129+
130+
To resolve this issue, you have several options:
131+
1. Fix your initial conditions (both `du0` and `u0`) to satisfy the DAE constraints
132+
2. Use Brown's full basic initialization: initializealg = DiffEqBase.BrownFullBasicInit()
133+
- Optional: specify tolerance with DiffEqBase.BrownFullBasicInit(abstol=1e-8)
134+
3. Use Shampine's collocation initialization: initializealg = DiffEqBase.ShampineCollocationInit()
135+
- Optional: specify initial dt with DiffEqBase.ShampineCollocationInit(0.001)
136+
4. If using ModelingToolkit, use: initializealg = SciMLBase.OverrideInit()
137+
138+
Example for automatic initialization:
139+
solve(prob, IDA(); initializealg = DiffEqBase.BrownFullBasicInit())
140+
""")
141+
end
142+
end

src/common_interface/initialize_dae.jl

Lines changed: 0 additions & 87 deletions
This file was deleted.

src/common_interface/integrator_types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ mutable struct ARKODEIntegrator{N,
134134
callback_cache::CallbackCacheType
135135
last_event_error::Float64
136136
initializealg::IA
137+
cfj1::Ptr{Cvoid}
138+
cfj2::Ptr{Cvoid}
137139
ctx_handle::ContextHandle
138140
end
139141

src/common_interface/integrator_utils.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function handle_callback_modifiers!(integrator::ARKODEIntegrator{N,
119119
CallbackCacheType,
120120
ARKStepMem}) where {N, pType, solType, algType, fType, UFType, JType, oType,
121121
LStype, Atype, MLStype, Mtype, CallbackCacheType}
122-
ARKStepReInit(integrator.mem, integrator.userfun.fun2, integrator.userfun.fun,
122+
ARKStepReInit(integrator.mem, integrator.cfj2, integrator.cfj1,
123123
integrator.t, integrator.u)
124124
end
125125

@@ -155,9 +155,11 @@ function IDAReinit!(integrator::IDAIntegrator)
155155
integrator.u_modified = false
156156
end
157157

158-
function handle_callback_modifiers!(integrator::IDAIntegrator)
158+
function handle_callback_modifiers!(integrator::IDAIntegrator, callback_initializealg = nothing)
159159
# Implicitly does IDAReinit!
160-
DiffEqBase.initialize_dae!(integrator)
160+
# Use callback's initialization algorithm if provided, otherwise use integrator's
161+
initializealg = isnothing(callback_initializealg) ? integrator.initializealg : callback_initializealg
162+
DiffEqBase.initialize_dae!(integrator, initializealg)
161163
end
162164

163165
function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator, t)
@@ -218,13 +220,15 @@ end
218220
end
219221
end
220222

221-
function DiffEqBase.reeval_internals_due_to_modification!(integrator::AbstractSundialsIntegrator)
223+
function DiffEqBase.reeval_internals_due_to_modification!(integrator::AbstractSundialsIntegrator,
224+
continuous_modification = true; callback_initializealg = nothing)
222225
integrator.userfun.p = integrator.p
223226
nothing
224227
end
225-
function DiffEqBase.reeval_internals_due_to_modification!(integrator::IDAIntegrator)
228+
function DiffEqBase.reeval_internals_due_to_modification!(integrator::IDAIntegrator,
229+
continuous_modification = true; callback_initializealg = nothing)
226230
integrator.userfun.p = integrator.p
227-
handle_callback_modifiers!(integrator::IDAIntegrator)
231+
handle_callback_modifiers!(integrator, callback_initializealg)
228232
end
229233

230234
# Required for callbacks

0 commit comments

Comments
 (0)