Skip to content

Commit fada5dd

Browse files
authored
Adding check for unsorted input coordinates when using QuasisepSolver (#123)
* adding exception and tests * adding news * fixing handling of coord_to_sortable for composite kernels
1 parent 2bdccea commit fada5dd

File tree

6 files changed

+74
-15
lines changed

6 files changed

+74
-15
lines changed

docs/tutorials/quasisep-custom.ipynb

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@
670670
],
671671
"metadata": {
672672
"kernelspec": {
673-
"display_name": "Python 3 (ipykernel)",
673+
"display_name": "Python 3.10.6 ('tinygp')",
674674
"language": "python",
675675
"name": "python3"
676676
},
@@ -684,7 +684,12 @@
684684
"name": "python",
685685
"nbconvert_exporter": "python",
686686
"pygments_lexer": "ipython3",
687-
"version": "3.9.9"
687+
"version": "3.10.6"
688+
},
689+
"vscode": {
690+
"interpreter": {
691+
"hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
692+
}
688693
}
689694
},
690695
"nbformat": 4,

news/123.feature

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added check for sorted input coordinates when using the ``QuasisepSolver``;
2+
a ``ValueError`` is thrown if they are not.

src/tinygp/gp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
solver: Optional[Any] = None,
6767
mean_value: Optional[JAXArray] = None,
6868
covariance_value: Optional[Any] = None,
69+
**solver_kwargs: Any,
6970
):
7071
self.kernel = kernel
7172
self.X = X
@@ -101,7 +102,11 @@ def __init__(
101102
else:
102103
solver = DirectSolver
103104
self.solver = solver.init(
104-
kernel, self.X, self.noise, covariance=covariance_value
105+
kernel,
106+
self.X,
107+
self.noise,
108+
covariance=covariance_value,
109+
**solver_kwargs,
105110
)
106111

107112
@property

src/tinygp/kernels/quasisep.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
]
2828

2929
from abc import ABCMeta, abstractmethod
30-
from typing import Optional, Union
30+
from typing import Any, Optional, Union
3131

3232
import jax
3333
import jax.numpy as jnp
@@ -151,7 +151,7 @@ def __add__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
151151
)
152152
return Sum(self, other)
153153

154-
def __radd__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
154+
def __radd__(self, other: Any) -> "Kernel":
155155
# We'll hit this first branch when using the `sum` function
156156
if other == 0:
157157
return self
@@ -171,7 +171,7 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
171171
)
172172
return Scale(kernel=self, scale=other)
173173

174-
def __rmul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
174+
def __rmul__(self, other: Any) -> "Kernel":
175175
if isinstance(other, Quasisep):
176176
return Product(other, self)
177177
if isinstance(other, Kernel) or jnp.ndim(other) != 0:
@@ -204,6 +204,9 @@ class Wrapper(Quasisep, metaclass=ABCMeta):
204204

205205
kernel: Quasisep
206206

207+
def coord_to_sortable(self, X: JAXArray) -> JAXArray:
208+
return self.kernel.coord_to_sortable(X)
209+
207210
def design_matrix(self) -> JAXArray:
208211
return self.kernel.design_matrix()
209212

@@ -226,6 +229,10 @@ class Sum(Quasisep):
226229
kernel1: Quasisep
227230
kernel2: Quasisep
228231

232+
def coord_to_sortable(self, X: JAXArray) -> JAXArray:
233+
"""We assume that both kernels use the same coordinates"""
234+
return self.kernel1.coord_to_sortable(X)
235+
229236
def design_matrix(self) -> JAXArray:
230237
return jsp.linalg.block_diag(
231238
self.kernel1.design_matrix(), self.kernel2.design_matrix()
@@ -259,6 +266,10 @@ class Product(Quasisep):
259266
kernel1: Quasisep
260267
kernel2: Quasisep
261268

269+
def coord_to_sortable(self, X: JAXArray) -> JAXArray:
270+
"""We assume that both kernels use the same coordinates"""
271+
return self.kernel1.coord_to_sortable(X)
272+
262273
def design_matrix(self) -> JAXArray:
263274
F1 = self.kernel1.design_matrix()
264275
F2 = self.kernel2.design_matrix()
@@ -699,14 +710,14 @@ def init(
699710
params = jnp.linalg.solve(
700711
params, 0.5 * sigma**2 * jnp.eye(p, 1, k=-p + 1)
701712
)[:, 0]
702-
stn = []
713+
stn_ = []
703714
for j in range(p):
704-
stn.append([jnp.zeros(()) for _ in range(p)])
715+
stn_.append([jnp.zeros(()) for _ in range(p)])
705716
for n, k in enumerate(range(j - 2, -1, -2)):
706-
stn[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
717+
stn_[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
707718
for n, k in enumerate(range(j, p, 2)):
708-
stn[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
709-
stn = jnp.array(list(map(jnp.stack, stn)))
719+
stn_[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
720+
stn = jnp.array(list(map(jnp.stack, stn_)))
710721

711722
return cls(
712723
sigma=sigma,

src/tinygp/solvers/quasisep/solver.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__all__ = ["QuasisepSolver"]
66

7-
from typing import Any, Optional
7+
from typing import TYPE_CHECKING, Any, Optional
88

99
import jax
1010
import jax.numpy as jnp
@@ -41,6 +41,7 @@ def init(
4141
noise: Noise,
4242
*,
4343
covariance: Optional[Any] = None,
44+
assume_sorted: bool = False,
4445
) -> "QuasisepSolver":
4546
"""Build a :class:`QuasisepSolver` for a given kernel and coordinates
4647
@@ -52,15 +53,24 @@ def init(
5253
covariance: Optionally, a pre-computed
5354
:class:`tinygp.solvers.quasisep.core.QSM` with the covariance
5455
matrix.
56+
assume_sorted: If ``True``, assume that the input coordinates are
57+
sorted. If ``False``, check that they are sorted and throw an
58+
error if they are not. This can introduce a runtime overhead,
59+
and you can pass ``assume_sorted=True`` to get the best
60+
performance.
5561
"""
5662
from tinygp.kernels.quasisep import Quasisep
5763

5864
if covariance is None:
59-
assert isinstance(kernel, Quasisep)
65+
if TYPE_CHECKING:
66+
assert isinstance(kernel, Quasisep)
67+
if not assume_sorted:
68+
jax.debug.callback(_check_sorted, kernel.coord_to_sortable(X))
6069
matrix = kernel.to_symm_qsm(X)
6170
matrix += noise.to_qsm()
6271
else:
63-
assert isinstance(covariance, SymmQSM)
72+
if TYPE_CHECKING:
73+
assert isinstance(covariance, SymmQSM)
6474
matrix = covariance
6575
factor = matrix.cholesky()
6676
return cls(X=X, matrix=matrix, factor=factor)
@@ -125,3 +135,10 @@ def condition(
125135

126136
A = self.solve_triangular(Ks)
127137
return Kss - A.transpose() @ A
138+
139+
140+
def _check_sorted(X: JAXArray) -> None:
141+
if np.any(np.diff(X) < 0.0):
142+
raise ValueError(
143+
"Input coordinates must be sorted in order to use the QuasisepSolver"
144+
)

tests/test_solvers/test_quasisep/test_solver.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_consistent_with_direct(kernel_pair, data):
109109

110110
@pytest.mark.skipif(celerite is None, reason="'celerite' must be installed")
111111
def test_celerite(data):
112-
x, y, t = data
112+
x, y, _ = data
113113
yerr = 0.1
114114

115115
a, b, c, d = 1.1, 0.8, 0.9, 0.1
@@ -125,3 +125,22 @@ def test_celerite(data):
125125
calc = gp.log_probability(y)
126126

127127
np.testing.assert_allclose(calc, expected)
128+
129+
130+
def test_unsorted(data):
131+
random = np.random.default_rng(0)
132+
inds = random.permutation(len(data[0]))
133+
x_ = data[0][inds]
134+
y_ = data[1][inds]
135+
136+
kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
137+
with pytest.raises(ValueError):
138+
GaussianProcess(kernel, x_, diag=0.1)
139+
140+
@jax.jit
141+
def impl(X, y):
142+
return GaussianProcess(kernel, X, diag=0.1).log_probability(y)
143+
144+
with pytest.raises(jax.lib.xla_extension.XlaRuntimeError) as exc_info:
145+
impl(x_, y_).block_until_ready()
146+
assert exc_info.match(r"Input coordinates must be sorted")

0 commit comments

Comments
 (0)