@@ -1197,3 +1197,197 @@ function generate_update_b(sys::System, b::AbstractVector; expression = Val{true
1197
1197
return maybe_compile_function (expression, wrap_gfw, (1 , 1 , is_split (sys)), res;
1198
1198
eval_expression, eval_module)
1199
1199
end
1200
+
1201
+ # f1 = rest
1202
+ # f2 = A * x + B * x2 + C
1203
+ function calculate_split_form (sys:: System ; sparse = false )
1204
+ rhss = [eq. rhs for eq in full_equations (sys)]
1205
+ dvs = unknowns (sys)
1206
+ A, B, x2, C = semiquadratic_form (rhss, dvs)
1207
+ if ! sparse
1208
+ A = collect (A)
1209
+ B = collect (B)
1210
+ end
1211
+ A = unwrap .(A)
1212
+ B = unwrap .(B)
1213
+ x2 = unwrap .(x2)
1214
+ C = unwrap .(C)
1215
+
1216
+ return A, B, x2, C
1217
+ end
1218
+
1219
+ const DIFFCACHE_PARAM_NAME = :__mtk_diffcache
1220
+
1221
+ function get_diffcache_param (:: Type{T} ) where {T}
1222
+ toconstant (Symbolics. variable (
1223
+ DIFFCACHE_PARAM_NAME; T = DiffCache{Vector{T}, Vector{T}}))
1224
+ end
1225
+
1226
+ # x2
1227
+ const BILINEAR_CACHEVAR = unwrap (only (@constants bilinear_xₘₜₖ:: Vector{Real} ))
1228
+ # A
1229
+ const LINEAR_MATRIX_PARAM_NAME = :linear_Aₘₜₖ
1230
+ function get_linear_matrix_param (size:: NTuple{2, Int} )
1231
+ m, n = size
1232
+ unwrap (only (@constants linear_Aₘₜₖ[1 : m, 1 : n]))
1233
+ end
1234
+ # B
1235
+ const BILINEAR_MATRIX_PARAM_NAME = :bilinear_Bₘₜₖ
1236
+ function get_bilinear_matrix_param (size:: NTuple{2, Int} )
1237
+ m, n = size
1238
+ unwrap (only (@constants bilinear_Bₘₜₖ[1 : m, 1 : n]))
1239
+ end
1240
+
1241
+ function generate_semiquadratic_functions (
1242
+ sys:: System , A, B, x2, C; expression = Val{true }, wrap_gfw = Val{false },
1243
+ eval_expression = false , eval_module = @__MODULE__ , kwargs... )
1244
+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
1245
+ bilinear_matrix_param = unwrap (getproperty (sys, BILINEAR_MATRIX_PARAM_NAME))
1246
+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
1247
+ dvs = unknowns (sys)
1248
+ ps = reorder_parameters (sys)
1249
+ # Codegen is a bit manual, and we're manually creating an efficient IIP function.
1250
+ # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1251
+ # argument.
1252
+ iip_x = generated_argument_name (2 )
1253
+ oop_x = generated_argument_name (1 )
1254
+
1255
+ f1_iip_ir = Assignment[Assignment (BILINEAR_CACHEVAR,
1256
+ term (view,
1257
+ term (PreallocationTools. get_tmp,
1258
+ diffcache_par, Symbolics. DEFAULT_OUTSYM),
1259
+ 1 : length (x2)))
1260
+ # write to x2
1261
+ Assignment (:__tmp1 , SetArray (false , BILINEAR_CACHEVAR, x2))
1262
+ # out .= C
1263
+ Assignment (
1264
+ :__tmp2 , SetArray (false , Symbolics. DEFAULT_OUTSYM, C))
1265
+ # mul!(out, B, x2, 1, 1)
1266
+ Assignment (:__tmp3 ,
1267
+ term (mul!, Symbolics. DEFAULT_OUTSYM, bilinear_matrix_param,
1268
+ BILINEAR_CACHEVAR, true , true ))]
1269
+ f1_iip = build_function_wrapper (
1270
+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys); p_start = 3 ,
1271
+ extra_assignments = f1_iip_ir, expression = Val{true }, kwargs... )
1272
+ f1_oop = build_function_wrapper (
1273
+ sys, term (+ , term (* , bilinear_matrix_param, x2), C), dvs, ps... ,
1274
+ get_iv (sys); expression = Val{true }, iip_config = (true , false ), kwargs... )
1275
+
1276
+ f2_iip_ir = Assignment[
1277
+ Assignment (
1278
+ :__tmp1 , term (mul!, Symbolics. DEFAULT_OUTSYM, linear_matrix_param, iip_x))
1279
+ ]
1280
+ f2_iip = build_function_wrapper (
1281
+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys); p_start = 3 ,
1282
+ extra_assignments = f2_iip_ir, expression = Val{true }, kwargs... )
1283
+ f2_oop = build_function_wrapper (
1284
+ sys, term (* , linear_matrix_param, oop_x), dvs, ps... , get_iv (sys);
1285
+ expression = Val{true }, iip_config = (true , false ), kwargs... )
1286
+
1287
+ f1 = maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1288
+ (f1_oop, f1_iip); eval_expression, eval_module)
1289
+ f2 = maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1290
+ (f2_oop, f2_iip); eval_expression, eval_module)
1291
+ return f1, f2
1292
+ end
1293
+
1294
+ function calculate_semiquadratic_jacobian (
1295
+ sys:: System , B, x2, C; sparse = false , massmatrix = calculate_massmatrix (sys))
1296
+ dvs = unknowns (sys)
1297
+ if sparse
1298
+ x2jac = Symbolics. sparsejacobian (x2, dvs)
1299
+ Cjac = Symbolics. sparsejacobian (C, dvs)
1300
+ else
1301
+ x2jac = Symbolics. jacobian (x2, dvs)
1302
+ Cjac = Symbolics. jacobian (C, dvs)
1303
+ end
1304
+
1305
+ f1jac = B * x2jac + Cjac
1306
+
1307
+ if sparse
1308
+ for i in 1 : length (dvs)
1309
+ massmatrix[i, i] == 0 && continue
1310
+ _iszero (f1jac[i, i]) || continue
1311
+ f1jac[i, i] = 1
1312
+ f1jac[i, i] = 0
1313
+ end
1314
+ end
1315
+
1316
+ return f1jac, x2jac, Cjac
1317
+ end
1318
+
1319
+ const COLPTR_PARAM = unwrap (only (@parameters __mtk_colptr:: Vector{Int} ))
1320
+ const ROWVAL_PARAM = unwrap (only (@parameters __mtk_rowval:: Vector{Int} ))
1321
+
1322
+ function generate_semiquadratic_jacobian (
1323
+ sys:: System , B, x2, C, f1jac, x2jac, Cjac; sparse = false ,
1324
+ expression = Val{true }, wrap_gfw = Val{false },
1325
+ eval_expression = false , eval_module = @__MODULE__ , kwargs... )
1326
+ if sparse
1327
+ @assert is_parameter (sys, COLPTR_PARAM)
1328
+ @assert is_parameter (sys, ROWVAL_PARAM)
1329
+ end
1330
+ bilinear_matrix_param = unwrap (getproperty (sys, BILINEAR_MATRIX_PARAM_NAME))
1331
+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
1332
+ dvs = unknowns (sys)
1333
+ ps = reorder_parameters (sys)
1334
+ # Codegen is a bit manual, and we're manually creating an efficient IIP function.
1335
+ # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1336
+ # argument.
1337
+ iip_x = generated_argument_name (2 )
1338
+ oop_x = generated_argument_name (1 )
1339
+
1340
+ iip_ir = Assignment[]
1341
+ push! (iip_ir,
1342
+ Assignment (:__mtk_preallocbuf ,
1343
+ term (PreallocationTools. get_tmp, diffcache_par, Symbolics. DEFAULT_OUTSYM)))
1344
+ if sparse
1345
+ push! (
1346
+ iip_ir, Assignment (:__mtk_nzvals , term (view, :__mtk_preallocbuf , 1 : nnz (x2jac))))
1347
+ push! (iip_ir, Assignment (:__tmp1 , SetArray (false , :__mtk_nzvals , x2jac. nzvals)))
1348
+ push! (iip_ir,
1349
+ Assignment (:__mtk_x2jacbuf ,
1350
+ term (SparseMatrixCSC, size (x2jac)... ,
1351
+ COLPTR_PARAM, ROWVAL_PARAM, :__mtk_nzvals )))
1352
+ cjac_idxs = AtIndex[]
1353
+ for (i, j, v) in zip (findnz (Cjac)... )
1354
+ push! (cjac_idxs, AtIndex (CartesianIndex (i, j), v))
1355
+ end
1356
+ else
1357
+ push! (iip_ir,
1358
+ Assignment (:__mtk_x2jacbuf ,
1359
+ term (reshape, term (view, :__mtk_preallocbuf , 1 : length (x2jac)), size (x2jac))))
1360
+ push! (iip_ir, Assignment (:__tmp1 , SetArray (false , :__mtk_x2jacbuf , x2jac)))
1361
+ cjac_idxs = AtIndex[]
1362
+ for i in eachindex (Cjac)
1363
+ _iszero (Cjac[i]) && continue
1364
+ push! (cjac_idxs, AtIndex (i, Cjac[i]))
1365
+ end
1366
+ end
1367
+ push! (iip_ir, Assignment (:__tmp2 , SetArray (false , Symbolics. DEFAULT_OUTSYM, cjac_idxs)))
1368
+ push! (iip_ir,
1369
+ Assignment (:__tmp3 ,
1370
+ term (mul!, Symbolics. DEFAULT_OUTSYM,
1371
+ bilinear_matrix_param, :__mtk_x2jacbuf , true , true )))
1372
+
1373
+ jaciip = build_function_wrapper (
1374
+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys);
1375
+ p_start = 3 , extra_assignments = iip_ir, expression = Val{true }, kwargs... )
1376
+
1377
+ make_x2 = if sparse
1378
+ MakeSparseArray (x2jac)
1379
+ else
1380
+ MakeArray (x2jac, generated_argument_name (1 ))
1381
+ end
1382
+ make_cjac = if sparse
1383
+ MakeSparseArray (Cjac)
1384
+ else
1385
+ MakeArray (Cjac, generated_argument_name (1 ))
1386
+ end
1387
+ oop_expr = term (+ , term (* , bilinear_matrix_param, make_x2), Cjac)
1388
+ jacoop = build_function_wrapper (
1389
+ sys, oop_expr, dvs, ps... , get_iv (sys); expression = Val{true }, kwargs... )
1390
+
1391
+ return maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1392
+ (jacoop, jaciip); eval_expression, eval_module)
1393
+ end
0 commit comments