@@ -221,8 +221,12 @@ def _init_sundials(self):
221
221
222
222
self ._ode = check (lib .CVodeCreate (self ._solver_kind ))
223
223
rhs = self ._problem .make_sundials_rhs ()
224
+ self ._rhs = rhs
224
225
check (lib .CVodeInit (self ._ode , rhs .cffi , 0. , self ._state_buffer .c_ptr ))
225
226
227
+ user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
228
+ check (lib .CVodeSetUserData (self ._ode , user_data_p ))
229
+
226
230
self ._set_tolerances (self ._abstol , self ._reltol )
227
231
if self ._constraints is not None :
228
232
assert self ._constraints .shape == (n_states ,)
@@ -231,9 +235,6 @@ def _init_sundials(self):
231
235
232
236
self ._make_linsol (self ._linear_solver_kind )
233
237
234
- user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
235
- check (lib .CVodeSetUserData (self ._ode , user_data_p ))
236
-
237
238
self ._compute_sens = self ._sens_mode is not None
238
239
if self ._compute_sens :
239
240
sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
@@ -340,15 +341,20 @@ def _init_sens(self, sens_rhs, sens_mode, scaling_factors=None) -> None:
340
341
def _set_tolerances (self , atol = None , rtol = None ):
341
342
atol = np .array (atol )
342
343
rtol = np .array (rtol )
344
+ n_states = self ._problem .n_states
343
345
if atol .ndim == 1 and rtol .ndim == 1 :
344
346
atol = sunode .from_numpy (atol )
345
347
rtol = sunode .from_numpy (rtol )
348
+ assert atol .shape == (n_states ,)
349
+ assert rtol .shape == (n_states ,)
346
350
check (lib .CVodeVVtolerances (self ._ode , rtol .c_ptr , atol .c_ptr ))
347
351
elif atol .ndim == 1 and rtol .ndim == 0 :
348
352
atol = sunode .from_numpy (atol )
353
+ assert atol .shape == (n_states ,)
349
354
check (lib .CVodeSVtolerances (self ._ode , rtol , atol .c_ptr ))
350
355
elif atol .ndim == 0 and rtol .ndim == 1 :
351
356
rtol = sunode .from_numpy (rtol )
357
+ assert rtol .shape == (n_states ,)
352
358
check (lib .CVodeVStolerances (self ._ode , rtol .c_ptr , atol ))
353
359
elif atol .ndim == 0 and rtol .ndim == 0 :
354
360
check (lib .CVodeSStolerances (self ._ode , rtol , atol ))
@@ -416,6 +422,7 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
416
422
TOO_MUCH_WORK = lib .CV_TOO_MUCH_WORK
417
423
418
424
n_params = self ._problem .n_params
425
+ n_states = self ._problem .n_states
419
426
420
427
state_data = self ._state_buffer .data
421
428
state_c_ptr = self ._state_buffer .c_ptr
@@ -428,6 +435,9 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
428
435
429
436
if y0 .dtype == self ._problem .state_dtype :
430
437
y0 = y0 [None ].view (np .float64 )
438
+
439
+ if y0 .shape != (n_states ,):
440
+ raise ValueError (f"y0 should have shape { (n_states ,)} but has shape { y0 .shape } ." )
431
441
state_data [:] = y0
432
442
433
443
time_p = ffi .new ('double*' )
0 commit comments