Skip to content

Commit 312e368

Browse files
committed
Add some safety checks
1 parent 93e2195 commit 312e368

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

sunode/problem.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def make_rhs_jac_prod(self): # type: ignore
4949
return NotImplemented
5050

5151
def make_user_data(self) -> np.ndarray:
52-
return np.recarray((), dtype=self.user_data_dtype)
52+
return np.zeros((), dtype=self.user_data_dtype).view(np.recarray)
5353

5454
def update_params(self, user_data: np.ndarray, params: np.ndarray) -> None:
5555
if not self.user_data_dtype == self.params_dtype:
@@ -166,10 +166,14 @@ def make_sundials_rhs(self) -> Any:
166166
func_type = numba.core.typing.cffi_utils.map_type(ffi.typeof('CVRhsFn'))
167167
func_type = func_type.return_type(*(func_type.args[:-1] + (user_ndtype_p,)))
168168

169+
n_states = self.n_states
170+
169171
@numba.cfunc(func_type)
170172
def rhs_wrapper(t, y_, out_, user_data_): # type: ignore
171173
y_ptr = N_VGetArrayPointer_Serial(y_)
172174
n_vars = N_VGetLength_Serial(y_)
175+
if n_vars != n_states:
176+
return -1
173177
out_ptr = N_VGetArrayPointer_Serial(out_)
174178
y = numba.carray(y_ptr, (n_vars,)).view(state_dtype)[0]
175179
out = numba.carray(out_ptr, (n_vars,))

sunode/solver.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ def _init_sundials(self):
221221

222222
self._ode = check(lib.CVodeCreate(self._solver_kind))
223223
rhs = self._problem.make_sundials_rhs()
224+
self._rhs = rhs
224225
check(lib.CVodeInit(self._ode, rhs.cffi, 0., self._state_buffer.c_ptr))
225226

227+
user_data_p = ffi.cast('void *', ffi.addressof(ffi.from_buffer(self._user_data.data)))
228+
check(lib.CVodeSetUserData(self._ode, user_data_p))
229+
226230
self._set_tolerances(self._abstol, self._reltol)
227231
if self._constraints is not None:
228232
assert self._constraints.shape == (n_states,)
@@ -231,9 +235,6 @@ def _init_sundials(self):
231235

232236
self._make_linsol(self._linear_solver_kind)
233237

234-
user_data_p = ffi.cast('void *', ffi.addressof(ffi.from_buffer(self._user_data.data)))
235-
check(lib.CVodeSetUserData(self._ode, user_data_p))
236-
237238
self._compute_sens = self._sens_mode is not None
238239
if self._compute_sens:
239240
sens_rhs = self._problem.make_sundials_sensitivity_rhs()
@@ -340,15 +341,20 @@ def _init_sens(self, sens_rhs, sens_mode, scaling_factors=None) -> None:
340341
def _set_tolerances(self, atol=None, rtol=None):
341342
atol = np.array(atol)
342343
rtol = np.array(rtol)
344+
n_states = self._problem.n_states
343345
if atol.ndim == 1 and rtol.ndim == 1:
344346
atol = sunode.from_numpy(atol)
345347
rtol = sunode.from_numpy(rtol)
348+
assert atol.shape == (n_states,)
349+
assert rtol.shape == (n_states,)
346350
check(lib.CVodeVVtolerances(self._ode, rtol.c_ptr, atol.c_ptr))
347351
elif atol.ndim == 1 and rtol.ndim == 0:
348352
atol = sunode.from_numpy(atol)
353+
assert atol.shape == (n_states,)
349354
check(lib.CVodeSVtolerances(self._ode, rtol, atol.c_ptr))
350355
elif atol.ndim == 0 and rtol.ndim == 1:
351356
rtol = sunode.from_numpy(rtol)
357+
assert rtol.shape == (n_states,)
352358
check(lib.CVodeVStolerances(self._ode, rtol.c_ptr, atol))
353359
elif atol.ndim == 0 and rtol.ndim == 0:
354360
check(lib.CVodeSStolerances(self._ode, rtol, atol))
@@ -416,6 +422,7 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
416422
TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK
417423

418424
n_params = self._problem.n_params
425+
n_states = self._problem.n_states
419426

420427
state_data = self._state_buffer.data
421428
state_c_ptr = self._state_buffer.c_ptr
@@ -428,6 +435,9 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
428435

429436
if y0.dtype == self._problem.state_dtype:
430437
y0 = y0[None].view(np.float64)
438+
439+
if y0.shape != (n_states,):
440+
raise ValueError(f"y0 should have shape {(n_states,)} but has shape {y0.shape}.")
431441
state_data[:] = y0
432442

433443
time_p = ffi.new('double*')

sunode/wrappers/as_pytensor.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def read_dict(vals, name=None):
5353
tensor, dim_names = vals
5454
else:
5555
try:
56-
tensor, dim_names = vals, pt.as_tensor_variable(vals).shape.eval()
56+
tensor, dim_names = vals, pt.as_tensor_variable(vals, dtype="float64").shape.eval()
5757
except MissingInputError as e:
5858
raise ValueError(
5959
'Shapes of tensors need to be statically '
6060
'known or given explicitly.') from e
6161
if isinstance(dim_names, (str, int)):
6262
dim_names = (dim_names,)
63-
tensor = pt.as_tensor_variable(tensor)
63+
tensor = pt.as_tensor_variable(tensor, dtype="float64")
6464
if tensor.ndim != len(dim_names):
6565
raise ValueError(
6666
f"Dimension mismatch for {name}: Value has rank {tensor.ndim}, "
@@ -97,30 +97,30 @@ def read_dict(vals, name=None):
9797
tensor = flat_tensors[path]
9898
if isinstance(tensor, tuple):
9999
tensor, _ = tensor
100-
vars.append(pt.as_tensor_variable(tensor).reshape((-1,)))
100+
vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
101101
if vars:
102102
params_subs_flat = pt.concatenate(vars)
103103
else:
104-
params_subs_flat = pt.as_tensor_variable(np.zeros(0))
104+
params_subs_flat = pt.as_tensor_variable(np.zeros(0), dtype="float64")
105105

106106
vars = []
107107
for path in problem.params_subset.remainder.subset_paths:
108108
tensor = flat_tensors[path]
109109
if isinstance(tensor, tuple):
110110
tensor, _ = tensor
111-
vars.append(pt.as_tensor_variable(tensor).reshape((-1,)))
111+
vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
112112
if vars:
113113
params_remaining_flat = pt.concatenate(vars)
114114
else:
115-
params_remaining_flat = pt.as_tensor_variable(np.zeros(0))
115+
params_remaining_flat = pt.as_tensor_variable(np.zeros(0), dtype="float64")
116116

117117
flat_tensors = as_flattened(y0)
118118
vars = []
119119
for path in problem.state_subset.paths:
120120
tensor = flat_tensors[path]
121121
if isinstance(tensor, tuple):
122122
tensor, _ = tensor
123-
vars.append(pt.as_tensor_variable(tensor).reshape((-1,)))
123+
vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
124124
y0_flat = pt.concatenate(vars)
125125

126126
if derivatives == 'adjoint':
@@ -261,6 +261,25 @@ class SolveODEAdjointBackward(Op):
261261

262262
__props__ = ('_solver_id', '_t0', '_tvals_id')
263263

264+
def make_nodes(self, *inputs):
265+
if len(inputs) != len(self.itypes):
266+
raise ValueError(
267+
f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
268+
)
269+
if not all(it.in_same_class(inp.type) for inp, it in zip(inputs, self.itypes)):
270+
raise TypeError(
271+
f"Invalid input types for Op {self}:\n"
272+
+ "\n".join(
273+
f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
274+
for i, (inp, out) in enumerate(
275+
zip(self.itypes, (inp.type for inp in inputs)),
276+
start=1,
277+
)
278+
if inp != out
279+
)
280+
)
281+
return Apply(self, inputs, [o() for o in self.otypes])
282+
264283
def __init__(self, solver, t0, tvals):
265284
self._solver = solver
266285
self._t0 = t0

0 commit comments

Comments
 (0)