Skip to content

Commit 7afcafa

Browse files
Setup DAE initialize interface and fix saveat endpoint for callbacks
Fixes #231
1 parent 93e2bf0 commit 7afcafa

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1616

1717
[compat]
1818
julia = "1"
19-
DiffEqBase = "6"
19+
DiffEqBase = "6.21"
2020
DataStructures = "0.17.0"
2121
BinaryProvider = "0.5"
2222
Reexport = "0.2"

src/common_interface/integrator_utils.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ function handle_callbacks!(integrator)
2121
if !(typeof(discrete_callbacks)<:Tuple{})
2222
discrete_modified,saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,discrete_callbacks...)
2323
end
24-
if !saved_in_cb
25-
savevalues!(integrator)
26-
end
2724

2825
integrator.u_modified = continuous_modified || discrete_modified
2926
if integrator.u_modified
3027
handle_callback_modifiers!(integrator)
3128
end
29+
30+
if !saved_in_cb
31+
savevalues!(integrator)
32+
end
33+
3234
integrator.u_modified = false
3335
end
3436

@@ -37,10 +39,11 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator,force_sav
3739
!integrator.opts.save_on && return saved, savedexactly
3840
uType = eltype(integrator.sol.u)
3941
while !isempty(integrator.opts.saveat) &&
40-
integrator.tdir*top(integrator.opts.saveat) < integrator.tdir*first(integrator.tout)
42+
integrator.tdir*top(integrator.opts.saveat) < integrator.tdir*integrator.t
4143

4244
saved = true
4345
curt = pop!(integrator.opts.saveat)
46+
4447
tmp = integrator(curt)
4548
save_value!(integrator.sol.u,tmp,uType,integrator.sizeu,Val{false})
4649
push!(integrator.sol.t,curt)
@@ -88,19 +91,7 @@ end
8891

8992
function handle_callback_modifiers!(integrator::IDAIntegrator)
9093
IDAReInit(integrator.mem,integrator.t,integrator.u,integrator.du)
91-
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
92-
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
93-
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
94-
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
95-
end
96-
if integrator.alg.init_all
97-
init_type = IDA_Y_INIT
98-
else
99-
init_type = IDA_YA_YDP_INIT
100-
integrator.flag = IDASetId(integrator.mem, integrator.sol.prob.differential_vars)
101-
end
102-
integrator.flag = IDACalcIC(integrator.mem, init_type, integrator.dt)
103-
end
94+
DiffEqBase.initialize_dae!(integrator)
10495
end
10596

10697
function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator,t)
@@ -154,3 +145,23 @@ end
154145
return getfield(integrator, sym)
155146
end
156147
end
148+
149+
DiffEqBase.reeval_internals_due_to_modification!(integrator::AbstractSundialsIntegrator) = nothing
150+
DiffEqBase.reeval_internals_due_to_modification!(integrator::IDAIntegrator) = handle_callback_modifiers!(integrator::IDAIntegrator)
151+
152+
DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing
153+
function DiffEqBase.initialize_dae!(integrator::IDAIntegrator)
154+
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
155+
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
156+
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
157+
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
158+
end
159+
if integrator.alg.init_all
160+
init_type = IDA_Y_INIT
161+
else
162+
init_type = IDA_YA_YDP_INIT
163+
integrator.flag = IDASetId(integrator.mem, integrator.sol.prob.differential_vars)
164+
end
165+
integrator.flag = IDACalcIC(integrator.mem, init_type, integrator.dt)
166+
end
167+
end

0 commit comments

Comments
 (0)