|
| 1 | +import numpy as np |
| 2 | + |
1 | 3 | from sunode import SympyProblem
|
2 |
| -from sunode.solver import Solver |
| 4 | +from sunode.solver import Solver, AdjointSolver |
3 | 5 |
|
4 | 6 |
|
5 | 7 | def test_basic():
|
@@ -74,3 +76,79 @@ def rhs(t, y, p):
|
74 | 76 | }
|
75 | 77 | problem = SympyProblem(params, states, rhs, derivative_params=[])
|
76 | 78 | Solver(problem)
|
| 79 | + |
| 80 | + |
| 81 | +def check_call_solve(solver, params, deriv): |
| 82 | + solver.set_params_dict(params) |
| 83 | + |
| 84 | + time = np.linspace(0, 1) |
| 85 | + if deriv == 'forward': |
| 86 | + y_buffer, sense_buffer = solver.make_output_buffers(time) |
| 87 | + solver.solve( |
| 88 | + 0, |
| 89 | + time, |
| 90 | + np.ones_like(y_buffer[0]), |
| 91 | + y_buffer, |
| 92 | + sens0=np.zeros_like(sense_buffer[0]), |
| 93 | + sens_out=sense_buffer |
| 94 | + ) |
| 95 | + elif deriv == 'backward': |
| 96 | + y_buffer, grads_buffer, lamda_buffer = solver.make_output_buffers(time) |
| 97 | + solver.solve_forward(0, time, np.ones_like(y_buffer[0]), y_buffer) |
| 98 | + |
| 99 | + grads = np.ones((len(time), y_buffer.shape[-1])) |
| 100 | + solver.solve_backward( |
| 101 | + time[-1], |
| 102 | + time[0], |
| 103 | + time, |
| 104 | + grads, |
| 105 | + grads_buffer, |
| 106 | + lamda_buffer |
| 107 | + ) |
| 108 | + elif deriv is None: |
| 109 | + y_buffer = solver.make_output_buffers(time) |
| 110 | + solver.solve( |
| 111 | + 0, |
| 112 | + time, |
| 113 | + np.ones_like(y_buffer[0]), |
| 114 | + y_buffer, |
| 115 | + ) |
| 116 | + else: |
| 117 | + assert False |
| 118 | + |
| 119 | + |
| 120 | +def test_declare_sens(): |
| 121 | + def rhs(t, y, p): |
| 122 | + return { |
| 123 | + 'x': y.x + p.a.b, |
| 124 | + } |
| 125 | + |
| 126 | + states = { |
| 127 | + 'x': (), |
| 128 | + } |
| 129 | + |
| 130 | + params = { |
| 131 | + 'a': { |
| 132 | + 'b': () |
| 133 | + } |
| 134 | + } |
| 135 | + |
| 136 | + param_vals = { |
| 137 | + 'a': { |
| 138 | + 'b': 0.2 |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + problem = SympyProblem(params, states, rhs, derivative_params=[('a', 'b')]) |
| 143 | + |
| 144 | + solver = Solver(problem, sens_mode="simultaneous") |
| 145 | + check_call_solve(solver, param_vals, "forward") |
| 146 | + |
| 147 | + solver = Solver(problem, sens_mode="staggered") |
| 148 | + check_call_solve(solver, param_vals, "forward") |
| 149 | + |
| 150 | + solver = Solver(problem) |
| 151 | + check_call_solve(solver, param_vals, None) |
| 152 | + |
| 153 | + solver = AdjointSolver(problem) |
| 154 | + check_call_solve(solver, param_vals, "backward") |
0 commit comments