Skip to content

Commit b9dc547

Browse files
feat: add SemilinearODEFunction and SemilinearODEProblem
1 parent d7fd3e4 commit b9dc547

File tree

4 files changed

+334
-1
lines changed

4 files changed

+334
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4545
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4646
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4747
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
48+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4849
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4950
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
5051
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -142,6 +143,7 @@ OrdinaryDiffEq = "6.82.0"
142143
OrdinaryDiffEqCore = "1.15.0"
143144
OrdinaryDiffEqDefault = "1.2"
144145
OrdinaryDiffEqNonlinearSolve = "1.5.0"
146+
PreallocationTools = "0.4.27"
145147
PrecompileTools = "1"
146148
Pyomo = "0.1.0"
147149
REPL = "1"

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ const DQ = DynamicQuantities
9999
import DifferentiationInterface as DI
100100
using ADTypes: AutoForwardDiff
101101
import SciMLPublic: @public
102+
import PreallocationTools
103+
import PreallocationTools: DiffCache
102104

103105
export @derivatives
104106

@@ -287,6 +289,7 @@ export IntervalNonlinearProblem
287289
export OptimizationProblem, constraints
288290
export SteadyStateProblem
289291
export JumpProblem
292+
export SemilinearODEFunction, SemilinearODEProblem
290293
export alias_elimination, flatten
291294
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
292295
instream

src/problems/odeproblem.jl

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,143 @@ end
9898
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
9999
end
100100

101+
struct SemilinearODEFunction{iip, spec} end
102+
103+
@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
104+
sys::System; u0 = nothing, p = nothing, t = nothing,
105+
semiquadratic_form = nothing, semiquadratic_jacobian = nothing,
106+
eval_expression = false, eval_module = @__MODULE__,
107+
expression = Val{false}, sparse = false, check_compatibility = true,
108+
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
109+
analytic = nothing, kwargs...) where {iip, specialize}
110+
check_complete(sys, SemilinearODEFunction)
111+
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
112+
113+
if semiquadratic_form === nothing
114+
sys = add_semilinear_parameters(sys)
115+
semiquadratic_form = calculate_split_form(sys; sparse)
116+
end
117+
118+
A, B, x2, C = semiquadratic_form
119+
M = calculate_massmatrix(sys)
120+
_M = concrete_massmatrix(M; sparse, u0)
121+
122+
f1, f2 = generate_semiquadratic_functions(
123+
sys, A, B, x2, C; expression, wrap_gfw = Val{true},
124+
eval_expression, eval_module, kwargs...)
125+
126+
if jac
127+
semiquadratic_jacobian = @something(semiquadratic_jacobian,
128+
calculate_semiquadratic_jacobian(sys, B, x2, C; sparse, massmatrix = _M))
129+
f1jac, x2jac, Cjac = semiquadratic_jacobian
130+
_jac = generate_semiquadratic_jacobian(
131+
sys, B, x2, C, f1jac, x2jac, Cjac; sparse, expression,
132+
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
133+
_W_sparsity = f1jac
134+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
135+
else
136+
_jac = nothing
137+
W_prototype = nothing
138+
end
139+
140+
observedfun = ObservedFunctionCache(
141+
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
142+
143+
f1_args = (; f1)
144+
f1_kwargs = (; jac = _jac)
145+
f1 = maybe_codegen_scimlfn(
146+
expression, ODEFunction{iip, specialize}, f1_args; f1_kwargs...)
147+
args = (; f1, f2)
148+
149+
kwargs = (;
150+
sys = sys,
151+
jac = _jac,
152+
mass_matrix = _M,
153+
jac_prototype = W_prototype,
154+
observed = observedfun,
155+
analytic,
156+
initialization_data)
157+
kwargs = (; sys, observed = observedfun, mass_matrix = _M)
158+
159+
return maybe_codegen_scimlfn(
160+
expression, SplitFunction{iip, specialize}, args; kwargs...)
161+
end
162+
163+
struct SemilinearODEProblem{iip, spec} end
164+
165+
@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
166+
sys::System, op, tspan; check_compatibility = true,
167+
u0_eltype = nothing, expression = Val{false}, callback = nothing,
168+
jac = false, sparse = false, kwargs...) where {iip, spec}
169+
check_complete(sys, SemilinearODEProblem)
170+
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
171+
172+
A, B, x2, C = semiquadratic_form = calculate_split_form(sys)
173+
174+
semiquadratic_jacobian = nothing
175+
if jac
176+
f1jac, x2jac, Cjac = semiquadratic_jacobian = calculate_semiquadratic_jacobian(
177+
sys, B, x2, C; sparse)
178+
end
179+
180+
sys = add_semilinear_parameters(sys)
181+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
182+
bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
183+
diffcache = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
184+
185+
floatT = calculate_float_type(op, typeof(op))
186+
_u0_eltype = something(u0_eltype, floatT)
187+
188+
guess = copy(guesses(sys))
189+
guess[linear_matrix_param] = fill(NaN, size(A))
190+
guess[bilinear_matrix_param] = fill(NaN, size(B))
191+
@set! sys.guesses = guess
192+
defs = copy(defaults(sys))
193+
defs[linear_matrix_param] = A
194+
defs[bilinear_matrix_param] = B
195+
cachelen = jac ? length(x2jac) : length(x2)
196+
defs[diffcache] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
197+
@set! sys.defaults = defs
198+
199+
f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
200+
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
201+
semiquadratic_form, semiquadratic_jacobian, jac, sparse, u0_eltype, kwargs...)
202+
203+
kwargs = process_kwargs(
204+
sys; expression, callback, kwargs...)
205+
206+
ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
207+
args = (; f, u0, tspan, p)
208+
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
209+
end
210+
211+
function add_semilinear_parameters(sys::System)
212+
m = length(equations(sys))
213+
n = length(unknowns(sys))
214+
linear_matrix_param = get_linear_matrix_param((m, n))
215+
bilinear_matrix_param = get_bilinear_matrix_param((m, (n^2 + n) ÷ 2))
216+
@assert !is_parameter(sys, linear_matrix_param)
217+
sys = with_additional_constant_parameter(sys, linear_matrix_param)
218+
@assert !is_parameter(sys, bilinear_matrix_param)
219+
sys = with_additional_constant_parameter(sys, bilinear_matrix_param)
220+
@assert !is_parameter(sys, get_diffcache_param(Float64))
221+
diffcache = get_diffcache_param(Float64)
222+
sys = with_additional_nonnumeric_parameter(sys, diffcache)
223+
var_to_name = copy(get_var_to_name(sys))
224+
var_to_name[LINEAR_MATRIX_PARAM_NAME] = linear_matrix_param
225+
var_to_name[BILINEAR_MATRIX_PARAM_NAME] = bilinear_matrix_param
226+
var_to_name[DIFFCACHE_PARAM_NAME] = diffcache
227+
@set! sys.var_to_name = var_to_name
228+
if get_parent(sys) !== nothing
229+
@set! sys.parent = add_semilinear_parameters(get_parent(sys))
230+
end
231+
return sys
232+
end
233+
101234
function check_compatible_system(
102235
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
103-
Type{DAEProblem}, Type{SteadyStateProblem}},
236+
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
237+
Type{SemilinearODEProblem}},
104238
sys::System)
105239
check_time_dependent(sys, T)
106240
check_not_dde(sys)

src/systems/codegen.jl

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,3 +1197,197 @@ function generate_update_b(sys::System, b::AbstractVector; expression = Val{true
11971197
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
11981198
eval_expression, eval_module)
11991199
end
1200+
1201+
# f1 = rest
1202+
# f2 = A * x + B * x2 + C
1203+
function calculate_split_form(sys::System; sparse = false)
1204+
rhss = [eq.rhs for eq in full_equations(sys)]
1205+
dvs = unknowns(sys)
1206+
A, B, x2, C = semiquadratic_form(rhss, dvs)
1207+
if !sparse
1208+
A = collect(A)
1209+
B = collect(B)
1210+
end
1211+
A = unwrap.(A)
1212+
B = unwrap.(B)
1213+
x2 = unwrap.(x2)
1214+
C = unwrap.(C)
1215+
1216+
return A, B, x2, C
1217+
end
1218+
1219+
const DIFFCACHE_PARAM_NAME = :__mtk_diffcache
1220+
1221+
function get_diffcache_param(::Type{T}) where {T}
1222+
toconstant(Symbolics.variable(
1223+
DIFFCACHE_PARAM_NAME; T = DiffCache{Vector{T}, Vector{T}}))
1224+
end
1225+
1226+
# x2
1227+
const BILINEAR_CACHEVAR = unwrap(only(@constants bilinear_xₘₜₖ::Vector{Real}))
1228+
# A
1229+
const LINEAR_MATRIX_PARAM_NAME = :linear_Aₘₜₖ
1230+
function get_linear_matrix_param(size::NTuple{2, Int})
1231+
m, n = size
1232+
unwrap(only(@constants linear_Aₘₜₖ[1:m, 1:n]))
1233+
end
1234+
# B
1235+
const BILINEAR_MATRIX_PARAM_NAME = :bilinear_Bₘₜₖ
1236+
function get_bilinear_matrix_param(size::NTuple{2, Int})
1237+
m, n = size
1238+
unwrap(only(@constants bilinear_Bₘₜₖ[1:m, 1:n]))
1239+
end
1240+
1241+
function generate_semiquadratic_functions(
1242+
sys::System, A, B, x2, C; expression = Val{true}, wrap_gfw = Val{false},
1243+
eval_expression = false, eval_module = @__MODULE__, kwargs...)
1244+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
1245+
bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
1246+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
1247+
dvs = unknowns(sys)
1248+
ps = reorder_parameters(sys)
1249+
# Codegen is a bit manual, and we're manually creating an efficient IIP function.
1250+
# Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1251+
# argument.
1252+
iip_x = generated_argument_name(2)
1253+
oop_x = generated_argument_name(1)
1254+
1255+
f1_iip_ir = Assignment[Assignment(BILINEAR_CACHEVAR,
1256+
term(view,
1257+
term(PreallocationTools.get_tmp,
1258+
diffcache_par, Symbolics.DEFAULT_OUTSYM),
1259+
1:length(x2)))
1260+
# write to x2
1261+
Assignment(:__tmp1, SetArray(false, BILINEAR_CACHEVAR, x2))
1262+
# out .= C
1263+
Assignment(
1264+
:__tmp2, SetArray(false, Symbolics.DEFAULT_OUTSYM, C))
1265+
# mul!(out, B, x2, 1, 1)
1266+
Assignment(:__tmp3,
1267+
term(mul!, Symbolics.DEFAULT_OUTSYM, bilinear_matrix_param,
1268+
BILINEAR_CACHEVAR, true, true))]
1269+
f1_iip = build_function_wrapper(
1270+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys); p_start = 3,
1271+
extra_assignments = f1_iip_ir, expression = Val{true}, kwargs...)
1272+
f1_oop = build_function_wrapper(
1273+
sys, term(+, term(*, bilinear_matrix_param, x2), C), dvs, ps...,
1274+
get_iv(sys); expression = Val{true}, iip_config = (true, false), kwargs...)
1275+
1276+
f2_iip_ir = Assignment[
1277+
Assignment(
1278+
:__tmp1, term(mul!, Symbolics.DEFAULT_OUTSYM, linear_matrix_param, iip_x))
1279+
]
1280+
f2_iip = build_function_wrapper(
1281+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys); p_start = 3,
1282+
extra_assignments = f2_iip_ir, expression = Val{true}, kwargs...)
1283+
f2_oop = build_function_wrapper(
1284+
sys, term(*, linear_matrix_param, oop_x), dvs, ps..., get_iv(sys);
1285+
expression = Val{true}, iip_config = (true, false), kwargs...)
1286+
1287+
f1 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
1288+
(f1_oop, f1_iip); eval_expression, eval_module)
1289+
f2 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
1290+
(f2_oop, f2_iip); eval_expression, eval_module)
1291+
return f1, f2
1292+
end
1293+
1294+
function calculate_semiquadratic_jacobian(
1295+
sys::System, B, x2, C; sparse = false, massmatrix = calculate_massmatrix(sys))
1296+
dvs = unknowns(sys)
1297+
if sparse
1298+
x2jac = Symbolics.sparsejacobian(x2, dvs)
1299+
Cjac = Symbolics.sparsejacobian(C, dvs)
1300+
else
1301+
x2jac = Symbolics.jacobian(x2, dvs)
1302+
Cjac = Symbolics.jacobian(C, dvs)
1303+
end
1304+
1305+
f1jac = B * x2jac + Cjac
1306+
1307+
if sparse
1308+
for i in 1:length(dvs)
1309+
massmatrix[i, i] == 0 && continue
1310+
_iszero(f1jac[i, i]) || continue
1311+
f1jac[i, i] = 1
1312+
f1jac[i, i] = 0
1313+
end
1314+
end
1315+
1316+
return f1jac, x2jac, Cjac
1317+
end
1318+
1319+
const COLPTR_PARAM = unwrap(only(@parameters __mtk_colptr::Vector{Int}))
1320+
const ROWVAL_PARAM = unwrap(only(@parameters __mtk_rowval::Vector{Int}))
1321+
1322+
function generate_semiquadratic_jacobian(
1323+
sys::System, B, x2, C, f1jac, x2jac, Cjac; sparse = false,
1324+
expression = Val{true}, wrap_gfw = Val{false},
1325+
eval_expression = false, eval_module = @__MODULE__, kwargs...)
1326+
if sparse
1327+
@assert is_parameter(sys, COLPTR_PARAM)
1328+
@assert is_parameter(sys, ROWVAL_PARAM)
1329+
end
1330+
bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
1331+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
1332+
dvs = unknowns(sys)
1333+
ps = reorder_parameters(sys)
1334+
# Codegen is a bit manual, and we're manually creating an efficient IIP function.
1335+
# Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1336+
# argument.
1337+
iip_x = generated_argument_name(2)
1338+
oop_x = generated_argument_name(1)
1339+
1340+
iip_ir = Assignment[]
1341+
push!(iip_ir,
1342+
Assignment(:__mtk_preallocbuf,
1343+
term(PreallocationTools.get_tmp, diffcache_par, Symbolics.DEFAULT_OUTSYM)))
1344+
if sparse
1345+
push!(
1346+
iip_ir, Assignment(:__mtk_nzvals, term(view, :__mtk_preallocbuf, 1:nnz(x2jac))))
1347+
push!(iip_ir, Assignment(:__tmp1, SetArray(false, :__mtk_nzvals, x2jac.nzvals)))
1348+
push!(iip_ir,
1349+
Assignment(:__mtk_x2jacbuf,
1350+
term(SparseMatrixCSC, size(x2jac)...,
1351+
COLPTR_PARAM, ROWVAL_PARAM, :__mtk_nzvals)))
1352+
cjac_idxs = AtIndex[]
1353+
for (i, j, v) in zip(findnz(Cjac)...)
1354+
push!(cjac_idxs, AtIndex(CartesianIndex(i, j), v))
1355+
end
1356+
else
1357+
push!(iip_ir,
1358+
Assignment(:__mtk_x2jacbuf,
1359+
term(reshape, term(view, :__mtk_preallocbuf, 1:length(x2jac)), size(x2jac))))
1360+
push!(iip_ir, Assignment(:__tmp1, SetArray(false, :__mtk_x2jacbuf, x2jac)))
1361+
cjac_idxs = AtIndex[]
1362+
for i in eachindex(Cjac)
1363+
_iszero(Cjac[i]) && continue
1364+
push!(cjac_idxs, AtIndex(i, Cjac[i]))
1365+
end
1366+
end
1367+
push!(iip_ir, Assignment(:__tmp2, SetArray(false, Symbolics.DEFAULT_OUTSYM, cjac_idxs)))
1368+
push!(iip_ir,
1369+
Assignment(:__tmp3,
1370+
term(mul!, Symbolics.DEFAULT_OUTSYM,
1371+
bilinear_matrix_param, :__mtk_x2jacbuf, true, true)))
1372+
1373+
jaciip = build_function_wrapper(
1374+
sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys);
1375+
p_start = 3, extra_assignments = iip_ir, expression = Val{true}, kwargs...)
1376+
1377+
make_x2 = if sparse
1378+
MakeSparseArray(x2jac)
1379+
else
1380+
MakeArray(x2jac, generated_argument_name(1))
1381+
end
1382+
make_cjac = if sparse
1383+
MakeSparseArray(Cjac)
1384+
else
1385+
MakeArray(Cjac, generated_argument_name(1))
1386+
end
1387+
oop_expr = term(+, term(*, bilinear_matrix_param, make_x2), Cjac)
1388+
jacoop = build_function_wrapper(
1389+
sys, oop_expr, dvs, ps..., get_iv(sys); expression = Val{true}, kwargs...)
1390+
1391+
return maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
1392+
(jacoop, jaciip); eval_expression, eval_module)
1393+
end

0 commit comments

Comments
 (0)