Skip to content

Commit a25fc29

Browse files
committed
Add some tests
1 parent 7699b95 commit a25fc29

File tree

1 file changed

+79
-1
lines changed

1 file changed

+79
-1
lines changed

sunode/test_solve.py

+79-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import numpy as np
2+
13
from sunode import SympyProblem
2-
from sunode.solver import Solver
4+
from sunode.solver import Solver, AdjointSolver
35

46

57
def test_basic():
@@ -74,3 +76,79 @@ def rhs(t, y, p):
7476
}
7577
problem = SympyProblem(params, states, rhs, derivative_params=[])
7678
Solver(problem)
79+
80+
81+
def check_call_solve(solver, params, deriv):
82+
solver.set_params_dict(params)
83+
84+
time = np.linspace(0, 1)
85+
if deriv == 'forward':
86+
y_buffer, sense_buffer = solver.make_output_buffers(time)
87+
solver.solve(
88+
0,
89+
time,
90+
np.ones_like(y_buffer[0]),
91+
y_buffer,
92+
sens0=np.zeros_like(sense_buffer[0]),
93+
sens_out=sense_buffer
94+
)
95+
elif deriv == 'backward':
96+
y_buffer, grads_buffer, lamda_buffer = solver.make_output_buffers(time)
97+
solver.solve_forward(0, time, np.ones_like(y_buffer[0]), y_buffer)
98+
99+
grads = np.ones((len(time), y_buffer.shape[-1]))
100+
solver.solve_backward(
101+
time[-1],
102+
time[0],
103+
time,
104+
grads,
105+
grads_buffer,
106+
lamda_buffer
107+
)
108+
elif deriv is None:
109+
y_buffer = solver.make_output_buffers(time)
110+
solver.solve(
111+
0,
112+
time,
113+
np.ones_like(y_buffer[0]),
114+
y_buffer,
115+
)
116+
else:
117+
assert False
118+
119+
120+
def test_declare_sens():
121+
def rhs(t, y, p):
122+
return {
123+
'x': y.x + p.a.b,
124+
}
125+
126+
states = {
127+
'x': (),
128+
}
129+
130+
params = {
131+
'a': {
132+
'b': ()
133+
}
134+
}
135+
136+
param_vals = {
137+
'a': {
138+
'b': 0.2
139+
}
140+
}
141+
142+
problem = SympyProblem(params, states, rhs, derivative_params=[('a', 'b')])
143+
144+
solver = Solver(problem, sens_mode="simultaneous")
145+
check_call_solve(solver, param_vals, "forward")
146+
147+
solver = Solver(problem, sens_mode="staggered")
148+
check_call_solve(solver, param_vals, "forward")
149+
150+
solver = Solver(problem)
151+
check_call_solve(solver, param_vals, None)
152+
153+
solver = AdjointSolver(problem)
154+
check_call_solve(solver, param_vals, "backward")

0 commit comments

Comments
 (0)