@@ -123,38 +123,86 @@ def read_dict(vals, name=None):
         vars.append(pt.as_tensor_variable(tensor, dtype="float64").reshape((-1,)))
     y0_flat = pt.concatenate(vars)
 
+    t0 = pt.as_tensor_variable(t0, dtype="float64")
+    tvals = pt.as_tensor_variable(tvals, dtype="float64")
+
     if derivatives == 'adjoint':
         sol = solver.AdjointSolver(problem, **solver_kwargs)
-        wrapper = SolveODEAdjoint(sol, t0, tvals)
-        flat_solution = wrapper(y0_flat, params_subs_flat, params_remaining_flat)
+        wrapper = SolveODEAdjoint(sol)
+        flat_solution = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
         solution = problem.flat_solution_as_dict(flat_solution)
         return solution, flat_solution, problem, sol, y0_flat, params_subs_flat
     elif derivatives == 'forward':
         if not "sens_mode" in solver_kwargs:
             raise ValueError("When `derivatives=True`, the `solver_kwargs` must contain one of `sens_mode={\"simultaneous\" | \"staggered\"}`.")
         sol = solver.Solver(problem, **solver_kwargs)
-        wrapper = SolveODE(sol, t0, tvals)
-        flat_solution, flat_sens = wrapper(y0_flat, params_subs_flat, params_remaining_flat)
+        wrapper = SolveODE(sol)
+        flat_solution, flat_sens = wrapper(y0_flat, params_subs_flat, params_remaining_flat, t0, tvals)
         solution = problem.flat_solution_as_dict(flat_solution)
         return solution, flat_solution, problem, sol, y0_flat, params_subs_flat, flat_sens, wrapper
     elif derivatives in [None, False]:
         sol = solver.Solver(problem, sens_mode=False)
         assert False
 
 
+class EvalRhs(Op):
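+    # Thin Op that evaluates the ODE right-hand side at given states and
+    # times. Used below to build the gradient of the solution with respect
+    # to the evaluation time points tvals.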
+    # params, params_fixed, y, tvals
+    itypes = [pt.dvector, pt.dvector, pt.dmatrix, pt.dvector]
+    otypes = [pt.dmatrix]
+
+    __props__ = ('_solver_id',)
+
+    def __init__(self, solver):
+        self._solver = solver
+        self._solver_id = id(solver)
+
+        self._deriv_dtype = self._solver.derivative_params_dtype
+        self._fixed_dtype = self._solver.remainder_params_dtype
+
+        # We only compile this when it is used, because we only need
+        # to evaluate this op if we need derivatives with respect to
+        # the solution evaluation time points, and that should be
+        # a small minority of use cases.
+        self._rhs = None
+
+    def perform(self, node, inputs, outputs):
+        params, params_fixed, y, tvals = inputs
+
+        if self._rhs is None:
+            self._rhs = self._solver._problem.make_rhs()
+
+        self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
+        self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
+
+        ok = True
+        retcode = 0
+        out = np.empty((len(tvals), self._solver._problem.n_states))
+        for i, t in enumerate(tvals):
+            retcode = self._rhs(
+                out[i],
+                t,
+                y[i].view(self._solver._problem.state_dtype)[0],
+                np.array(self._solver._user_data)[()]
+            )
+            ok = ok and not retcode
+        if not ok:
+            raise ValueError(f"Bad ode rhs return code: {retcode}")
+
+        outputs[0][0] = out
+
+
 class SolveODE(Op):
-    itypes = [pt.dvector, pt.dvector, pt.dvector]
+    # y0, params, params_fixed, t0, tvals
+    itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dscalar, pt.dvector]
+    # y_out, y_sens_out
     otypes = [pt.dmatrix, pt.dtensor3]
-
-    __props__ = ('_solver_id', '_t0', '_tvals_id')
-
-    def __init__(self, solver, t0, tvals):
+
+    __props__ = ('_solver_id',)
+
+    def __init__(self, solver):
         self._solver = solver
-        self._t0 = t0
-        self._y_out, self._sens_out = solver.make_output_buffers(tvals)
-        self._tvals = tvals
         self._solver_id = id(solver)
-        self._tvals_id = id(self._tvals)
+
         self._deriv_dtype = self._solver.derivative_params_dtype
         self._fixed_dtype = self._solver.remainder_params_dtype
 
@@ -190,108 +238,101 @@ def get(val, path):
         self._sens0 = sens0.reshape((n_params, n_states))
 
     def perform(self, node, inputs, outputs):
-        y0, params, params_fixed = inputs
+        y0, params, params_fixed, t0, tvals = inputs
+        y_out, sens_out = self._solver.make_output_buffers(tvals)
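+        # The output buffers are allocated per call now that tvals is a
+        # symbolic input (its length can differ between calls), so the
+        # outputs below no longer need a defensive copy.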
 
         self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
         self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
 
         try:
-            self._solver.solve(self._t0, self._tvals, y0, self._y_out,
-                               sens0=self._sens0, sens_out=self._sens_out)
+            self._solver.solve(
+                t0, tvals, y0, y_out,
+                sens0=self._sens0, sens_out=sens_out
+            )
         except SolverError:
-            self._y_out[...] = np.nan
-            self._sens_out[...] = np.nan
-
-        outputs[0][0] = self._y_out.copy()
-        outputs[1][0] = self._sens_out.copy()
+            y_out[...] = np.nan
+            sens_out[...] = np.nan
+
+        outputs[0][0] = y_out
+        outputs[1][0] = sens_out
 
 
     def grad(self, inputs, g):
         g, g_grad = g
-
+        _, params, params_fixed, t0, tvals = inputs
+
         assert str(g_grad) == '<DisconnectedType>'
         solution, sens = self(*inputs)
         return [
             pt.zeros_like(inputs[0]),
             pt.sum(g[:, None, :] * sens, (0, -1)),
-            grad_not_implemented(self, 2, inputs[-1])
+            grad_not_implemented(self, 2, params_fixed),
+            grad_not_implemented(self, 3, t0),
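+            # d y(t_i)/d t_i is just the ODE right-hand side evaluated at
+            # (t_i, y(t_i)), so the tvals gradient contracts it with g.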
+            (EvalRhs(self._solver)(params, params_fixed, solution, tvals) * g).sum(-1),
         ]
 
 
 class SolveODEAdjoint(Op):
-    itypes = [pt.dvector, pt.dvector, pt.dvector]
+    # y0, params, params_fixed, t0, tvals
+    itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dscalar, pt.dvector]
     otypes = [pt.dmatrix]
 
-    __props__ = ('_solver_id', '_t0', '_tvals_id')
+    __props__ = ('_solver_id',)
 
-    def __init__(self, solver, t0, tvals):
+    def __init__(self, solver):
         self._solver = solver
-        self._t0 = t0
-        self._y_out, self._grad_out, self._lamda_out = solver.make_output_buffers(tvals)
-        self._tvals = tvals
         self._solver_id = id(solver)
-        self._tvals_id = id(self._tvals)
         self._deriv_dtype = self._solver.derivative_params_dtype
         self._fixed_dtype = self._solver.remainder_params_dtype
 
     def perform(self, node, inputs, outputs):
-        y0, params, params_fixed = inputs
+        y0, params, params_fixed, t0, tvals = inputs
+
+        y_out, grad_out, lamda_out = self._solver.make_output_buffers(tvals)
 
         self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
         self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
 
         try:
-            self._solver.solve_forward(self._t0, self._tvals, y0, self._y_out)
+            self._solver.solve_forward(t0, tvals, y0, y_out)
         except SolverError as e:
-            self._y_out[:] = np.nan
+            y_out[:] = np.nan
 
-        outputs[0][0] = self._y_out.copy()
+        outputs[0][0] = y_out.copy()
 
     def grad(self, inputs, g):
         g, = g
 
-        y0, params, params_fixed = inputs
-        backward = SolveODEAdjointBackward(self._solver, self._t0, self._tvals)
-        lamda, gradient = backward(y0, params, params_fixed, g)
-        return [-lamda, gradient, grad_not_implemented(self, 2, params_fixed)]
+        y0, params, params_fixed, t0, tvals = inputs
+        solution = self(*inputs)
+        backward = SolveODEAdjointBackward(self._solver)
+        lamda, gradient = backward(y0, params, params_fixed, g, t0, tvals)
+
+        return [
+            -lamda,
+            gradient,
+            grad_not_implemented(self, 2, params_fixed),
+            grad_not_implemented(self, 3, t0),
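+            # Same as in SolveODE.grad: the tvals gradient is the ODE
+            # right-hand side at (t_i, y(t_i)), contracted with g.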
+            (EvalRhs(self._solver)(params, params_fixed, solution, tvals) * g).sum(-1),
+        ]
 
 
 class SolveODEAdjointBackward(Op):
-    itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dmatrix]
+    # y0, params, params_fixed, g, t0, tvals
+    itypes = [pt.dvector, pt.dvector, pt.dvector, pt.dmatrix, pt.dscalar, pt.dvector]
     otypes = [pt.dvector, pt.dvector]
 
-    __props__ = ('_solver_id', '_t0', '_tvals_id')
+    __props__ = ('_solver_id',)
 
-    def make_nodes(self, *inputs):
-        if len(inputs) != len(self.itypes):
-            raise ValueError(
-                f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
-            )
-        if not all(it.in_same_class(inp.type) for inp, it in zip(inputs, self.itypes)):
-            raise TypeError(
-                f"Invalid input types for Op {self}:\n"
-                + "\n".join(
-                    f"Input {i}/{len(inputs)}: Expected {inp}, got {out}"
-                    for i, (inp, out) in enumerate(
-                        zip(self.itypes, (inp.type for inp in inputs)),
-                        start=1,
-                    )
-                    if inp != out
-                )
-            )
-        return Apply(self, inputs, [o() for o in self.otypes])
-
-    def __init__(self, solver, t0, tvals):
+    def __init__(self, solver):
         self._solver = solver
-        self._t0 = t0
-        self._y_out, self._grad_out, self._lamda_out = solver.make_output_buffers(tvals)
-        self._tvals = tvals
         self._solver_id = id(solver)
-        self._tvals_id = id(self._tvals)
         self._deriv_dtype = self._solver.derivative_params_dtype
         self._fixed_dtype = self._solver.remainder_params_dtype
 
     def perform(self, node, inputs, outputs):
-        y0, params, params_fixed, grads = inputs
+        y0, params, params_fixed, grads, t0, tvals = inputs
+
+        y_out, grad_out, lamda_out = self._solver.make_output_buffers(tvals)
 
         self._solver.set_derivative_params(params.view(self._deriv_dtype)[0])
         self._solver.set_remaining_params(params_fixed.view(self._fixed_dtype)[0])
@@ -300,12 +341,12 @@ def perform(self, node, inputs, outputs):
         # that it was executed previously, but it isn't very expensive
         # compared with the backward pass anyway.
         try:
-            self._solver.solve_forward(self._t0, self._tvals, y0, self._y_out)
-            self._solver.solve_backward(self._tvals[-1], self._t0, self._tvals,
-                                        grads, self._grad_out, self._lamda_out)
+            self._solver.solve_forward(t0, tvals, y0, y_out)
+            self._solver.solve_backward(tvals[-1], t0, tvals,
+                                        grads, grad_out, lamda_out)
         except SolverError as e:
-            self._lamda_out[:] = np.nan
-            self._grad_out[:] = np.nan
+            lamda_out[:] = np.nan
+            grad_out[:] = np.nan
 
-        outputs[0][0] = self._lamda_out.copy()
-        outputs[1][0] = self._grad_out.copy()
+        outputs[0][0] = lamda_out
+        outputs[1][0] = grad_out
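
For context, a minimal sketch of how the reworked forward-sensitivity Op is invoked after this change. `problem`, `y0_flat`, `params_subs_flat`, and `params_remaining_flat` are hypothetical stand-ins for values the surrounding module builds; the point is that `t0` and `tvals` are now symbolic inputs to the Op rather than constructor arguments, so a single Op instance works for any time grid:

    import pytensor.tensor as pt

    # Hypothetical setup; `problem` comes from the surrounding module.
    sol = solver.Solver(problem, sens_mode="simultaneous")

    # t0/tvals are graph inputs now, not baked into the Op via __props__.
    t0 = pt.dscalar("t0")
    tvals = pt.dvector("tvals")

    wrapper = SolveODE(sol)
    flat_solution, flat_sens = wrapper(
        y0_flat, params_subs_flat, params_remaining_flat, t0, tvals
    )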