Skip to content

Commit 84cc17e

Browse files
committed
Support gradients with respect to ode evaluation time
1 parent 312e368 commit 84cc17e

File tree

2 files changed

+138
-77
lines changed

2 files changed

+138
-77
lines changed

sunode/test_pytensor.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@ def dydt_dict(t, y, p):
1616
A.tag.test_value = np.array(0.9)
1717

1818

19-
time = np.linspace(0, 1)
19+
time = pt.linspace(0, 1, 5)
2020

2121
y0 = {
2222
'A': (A, ()),
2323
'B': np.array(1.),
2424
'C': np.array(1.)
2525
}
2626

27+
beta = pt.dscalar("beta")
28+
2729
params = {
2830
'alpha': np.array(1.),
29-
'beta': np.array(1.),
31+
'beta': beta,
3032
'extra': np.array([0.])
3133
}
3234

@@ -35,10 +37,28 @@ def dydt_dict(t, y, p):
3537
params=params,
3638
rhs=dydt_dict,
3739
tvals=time,
38-
t0=time[0],
40+
t0=0.,
3941
derivatives="forward",
4042
solver_kwargs=dict(sens_mode="simultaneous")
4143
)
4244

43-
func = pytensor.function([A], [solution["A"], solution["B"]])
44-
assert func(0.2)[0].shape == time.shape
45+
grad = pt.grad(solution["A"].sum(), time)
46+
47+
func = pytensor.function([A, beta], [solution["A"], solution["B"], grad])
48+
assert func(0.2, 1.)[0].shape == (5,)
49+
assert func(0.2, 1.)[2].shape == (5,)
50+
51+
solution, *_ = sunode.wrappers.as_pytensor.solve_ivp(
52+
y0=y0,
53+
params=params,
54+
rhs=dydt_dict,
55+
tvals=time,
56+
t0=0.,
57+
derivatives="adjoint",
58+
)
59+
60+
grad = pt.grad(solution["A"].sum(), time)
61+
62+
func = pytensor.function([A, beta], [solution["A"], solution["B"], grad])
63+
assert func(0.2, 1.)[0].shape == (5,)
64+
assert func(0.2, 1.)[2].shape == (5,)

sunode/wrappers/as_pytensor.py

+113-72
Original file line numberDiff line numberDiff line change
@@ -123,38 +123,86 @@ def read_dict(vals, name=None):
123123
vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
124124
y0_flat = pt.concatenate(vars)
125125

126+
t0 = pt.as_tensor_variable(t0, dtype="float64")
127+
tvals = pt.as_tensor_variable(tvals, dtype="float64")
128+
126129
if derivatives == 'adjoint':
127130
sol = solver.AdjointSolver(problem, **solver_kwargs)
128-
wrapper = SolveODEAdjoint(sol, t0, tvals)
129-
flat_solution = wrapper(y0_flat, params_subs_flat, params_remaining_flat)
131+
wrapper = SolveODEAdjoint(sol)
132+
flat_solution = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
130133
solution = problem.flat_solution_as_dict(flat_solution)
131134
return solution, flat_solution, problem, sol, y0_flat, params_subs_flat
132135
elif derivatives == 'forward':
133136
if not "sens_mode" in solver_kwargs:
134137
raise ValueError("When `derivatives=True`, the `solver_kwargs` must contain one of `sens_mode={\"simultaneous\" | \"staggered\"}`.")
135138
sol = solver.Solver(problem, **solver_kwargs)
136-
wrapper = SolveODE(sol, t0, tvals)
137-
flat_solution, flat_sens = wrapper(y0_flat, params_subs_flat, params_remaining_flat)
139+
wrapper = SolveODE(sol)
140+
flat_solution, flat_sens = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
138141
solution = problem.flat_solution_as_dict(flat_solution)
139142
return solution, flat_solution, problem, sol, y0_flat, params_subs_flat, flat_sens, wrapper
140143
elif derivatives in [None, False]:
141144
sol = solver.Solver(problem, sens_mode=False)
142145
assert False
143146

144147

148+
class EvalRhs(Op):
149+
# params, params_fixed, y, tvals
150+
itypes = [pt.dvector, pt.dvector, pt.dmatrix, pt.dvector]
151+
otypes = [pt.dmatrix]
152+
153+
__props__ = ('_solver_id',)
154+
155+
def __init__(self, solver):
156+
self._solver = solver
157+
self._solver_id = id(solver)
158+
159+
self._deriv_dtype = self._solver.derivative_params_dtype
160+
self._fixed_dtype = self._solver.remainder_params_dtype
161+
162+
# We only compile this when it is used, because we only need
163+
# to evaluate this op if we need derivatives with respect to
164+
# the solution evaluation time points, and that should be
165+
# a small minority of use cases.
166+
self._rhs = None
167+
168+
def perform(self, node, inputs, outputs):
169+
params, params_fixed, y, tvals = inputs
170+
171+
if self._rhs is None:
172+
self._rhs = self._solver._problem.make_rhs()
173+
174+
self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
175+
self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
176+
177+
ok = True
178+
retcode = 0
179+
out = np.empty((len(tvals), self._solver._problem.n_states))
180+
for i, t in enumerate(tvals):
181+
retcode = self._rhs(
182+
out[i],
183+
t,
184+
y[i].view(self._solver._problem.state_dtype)[0],
185+
np.array(self._solver._user_data)[()]
186+
)
187+
ok = ok and not retcode
188+
if not ok:
189+
raise ValueError(f"Bad ode rhs return code: {retcode}")
190+
191+
outputs[0][0] = out
192+
193+
145194
class SolveODE(Op):
146-
itypes = [pt.dvector, pt.dvector, pt.dvector]
195+
# y0, params, params_fixed, t0, tvals
196+
itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dscalar, pt.dvector]
197+
# y_out, y_sens_out
147198
otypes = [pt.dmatrix, pt.dtensor3]
148-
149-
__props__ = ('_solver_id', '_t0', '_tvals_id')
150-
151-
def __init__(self, solver, t0, tvals):
199+
200+
__props__ = ('_solver_id',)
201+
202+
def __init__(self, solver):
152203
self._solver = solver
153-
self._t0 = t0
154-
self._y_out, self._sens_out = solver.make_output_buffers(tvals)
155-
self._tvals = tvals
156204
self._solver_id = id(solver)
157-
self._tvals_id = id(self._tvals)
205+
158206
self._deriv_dtype = self._solver.derivative_params_dtype
159207
self._fixed_dtype = self._solver.remainder_params_dtype
160208

@@ -190,108 +238,101 @@ def get(val, path):
190238
self._sens0 = sens0.reshape((n_params, n_states))
191239

192240
def perform(self, node, inputs, outputs):
193-
y0, params, params_fixed = inputs
241+
y0, params, params_fixed, t0, tvals = inputs
242+
y_out, sens_out = self._solver.make_output_buffers(tvals)
194243

195244
self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
196245
self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
197246

198247
try:
199-
self._solver.solve(self._t0, self._tvals, y0, self._y_out,
200-
sens0=self._sens0, sens_out=self._sens_out)
248+
self._solver.solve(
249+
t0, tvals, y0, y_out,
250+
sens0=self._sens0, sens_out=sens_out
251+
)
201252
except SolverError:
202-
self._y_out[...] = np.nan
203-
self._sens_out[...] = np.nan
204-
205-
outputs[0][0] = self._y_out.copy()
206-
outputs[1][0] = self._sens_out.copy()
253+
y_out[...] = np.nan
254+
sens_out[...] = np.nan
255+
256+
outputs[0][0] = y_out
257+
outputs[1][0] = sens_out
207258

208259
def grad(self, inputs, g):
209260
g, g_grad = g
210-
261+
_, params, params_fixed, t0, tvals = inputs
262+
211263
assert str(g_grad) == '<DisconnectedType>'
212264
solution, sens = self(*inputs)
213265
return [
214266
pt.zeros_like(inputs[0]),
215267
pt.sum(g[:, None, :] * sens, (0, -1)),
216-
grad_not_implemented(self, 2, inputs[-1])
268+
grad_not_implemented(self, 2, params_fixed),
269+
grad_not_implemented(self, 3, t0),
270+
(EvalRhs(self._solver)(params, params_fixed, solution, tvals) * g).sum(-1),
217271
]
218272

219273

220274
class SolveODEAdjoint(Op):
221-
itypes = [pt.dvector, pt.dvector, pt.dvector]
275+
# y0, params, params_fixed, t0, tvals
276+
itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dscalar, pt.dvector]
222277
otypes = [pt.dmatrix]
223278

224-
__props__ = ('_solver_id', '_t0', '_tvals_id')
279+
__props__ = ('_solver_id',)
225280

226-
def __init__(self, solver, t0, tvals):
281+
def __init__(self, solver):
227282
self._solver = solver
228-
self._t0 = t0
229-
self._y_out, self._grad_out, self._lamda_out = solver.make_output_buffers(tvals)
230-
self._tvals = tvals
231283
self._solver_id = id(solver)
232-
self._tvals_id = id(self._tvals)
233284
self._deriv_dtype = self._solver.derivative_params_dtype
234285
self._fixed_dtype = self._solver.remainder_params_dtype
235286

236287
def perform(self, node, inputs, outputs):
237-
y0, params, params_fixed = inputs
288+
y0, params, params_fixed, t0, tvals = inputs
289+
290+
y_out, grad_out, lamda_out = self._solver.make_output_buffers(tvals)
238291

239292
self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
240293
self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
241294

242295
try:
243-
self._solver.solve_forward(self._t0, self._tvals, y0, self._y_out)
296+
self._solver.solve_forward(t0, tvals, y0, y_out)
244297
except SolverError as e:
245-
self._y_out[:] = np.nan
298+
y_out[:] = np.nan
246299

247-
outputs[0][0] = self._y_out.copy()
300+
outputs[0][0] = y_out.copy()
248301

249302
def grad(self, inputs, g):
250303
g, = g
251304

252-
y0, params, params_fixed = inputs
253-
backward = SolveODEAdjointBackward(self._solver, self._t0, self._tvals)
254-
lamda, gradient = backward(y0, params, params_fixed, g)
255-
return [-lamda, gradient, grad_not_implemented(self, 2, params_fixed)]
305+
y0, params, params_fixed, t0, tvals = inputs
306+
solution = self(*inputs)
307+
backward = SolveODEAdjointBackward(self._solver)
308+
lamda, gradient = backward(y0, params, params_fixed, g, t0, tvals)
309+
310+
return [
311+
-lamda,
312+
gradient,
313+
grad_not_implemented(self, 2, params_fixed),
314+
grad_not_implemented(self, 3, t0),
315+
(EvalRhs(self._solver)(params, params_fixed, solution, tvals) * g).sum(-1),
316+
]
256317

257318

258319
class SolveODEAdjointBackward(Op):
259-
itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dmatrix]
320+
# y0, params, params_fixed, g, t0, tvals
321+
itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dmatrix, pt.dscalar, pt.dvector]
260322
otypes = [pt.dvector, pt.dvector]
261323

262-
__props__ = ('_solver_id', '_t0', '_tvals_id')
324+
__props__ = ('_solver_id',)
263325

264-
def make_nodes(self, *inputs):
265-
if len(inputs) != len(self.itypes):
266-
raise ValueError(
267-
f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
268-
)
269-
if not all(it.in_same_class(inp.type) for inp, it in zip(inputs, self.itypes)):
270-
raise TypeError(
271-
f"Invalid input types for Op {self}:\n"
272-
+ "\n".join(
273-
f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
274-
for i, (inp, out) in enumerate(
275-
zip(self.itypes, (inp.type for inp in inputs)),
276-
start=1,
277-
)
278-
if inp != out
279-
)
280-
)
281-
return Apply(self, inputs, [o() for o in self.otypes])
282-
283-
def __init__(self, solver, t0, tvals):
326+
def __init__(self, solver):
284327
self._solver = solver
285-
self._t0 = t0
286-
self._y_out, self._grad_out, self._lamda_out = solver.make_output_buffers(tvals)
287-
self._tvals = tvals
288328
self._solver_id = id(solver)
289-
self._tvals_id = id(self._tvals)
290329
self._deriv_dtype = self._solver.derivative_params_dtype
291330
self._fixed_dtype = self._solver.remainder_params_dtype
292331

293332
def perform(self, node, inputs, outputs):
294-
y0, params, params_fixed, grads = inputs
333+
y0, params, params_fixed, grads, t0, tvals = inputs
334+
335+
y_out, grad_out, lamda_out = self._solver.make_output_buffers(tvals)
295336

296337
self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
297338
self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
@@ -300,12 +341,12 @@ def perform(self, node, inputs, outputs):
300341
# that it was executed previously, but it isn't very expensive
301342
# compared with the backward pass anyway.
302343
try:
303-
self._solver.solve_forward(self._t0, self._tvals, y0, self._y_out)
304-
self._solver.solve_backward(self._tvals[-1], self._t0, self._tvals,
305-
grads, self._grad_out, self._lamda_out)
344+
self._solver.solve_forward(t0, tvals, y0, y_out)
345+
self._solver.solve_backward(tvals[-1], t0, tvals,
346+
grads, grad_out, lamda_out)
306347
except SolverError as e:
307-
self._lamda_out[:] = np.nan
308-
self._grad_out[:] = np.nan
348+
lamda_out[:] = np.nan
349+
grad_out[:] = np.nan
309350

310-
outputs[0][0] = self._lamda_out.copy()
311-
outputs[1][0] = self._grad_out.copy()
351+
outputs[0][0] = lamda_out
352+
outputs[1][0] = grad_out

0 commit comments

Comments
 (0)