Skip to content

Commit 20f0e04

Browse files
committed
trying to get sorting to work
1 parent c1893ca commit 20f0e04

File tree

4 files changed

+91
-9
lines changed

4 files changed

+91
-9
lines changed

src/tinygp/gp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
solver: Optional[Any] = None,
6767
mean_value: Optional[JAXArray] = None,
6868
covariance_value: Optional[Any] = None,
69+
**solver_kwargs: Any,
6970
):
7071
self.kernel = kernel
7172
self.X = X
@@ -101,7 +102,11 @@ def __init__(
101102
else:
102103
solver = DirectSolver
103104
self.solver = solver.init(
104-
kernel, self.X, self.noise, covariance=covariance_value
105+
kernel,
106+
self.X,
107+
self.noise,
108+
covariance=covariance_value,
109+
**solver_kwargs,
105110
)
106111

107112
@property

src/tinygp/kernels/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ def __radd__(self, other: Any) -> "Kernel":
118118
# We'll hit this first branch when using the `sum` function
119119
if other == 0:
120120
return self
121-
if isinstance(other, Kernel):
122-
return Sum(other, self)
123121
return Sum(Constant(other), self)
124122

125123
def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
@@ -128,8 +126,6 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
128126
return Product(self, Constant(other))
129127

130128
def __rmul__(self, other: Any) -> "Kernel":
131-
if isinstance(other, Kernel):
132-
return Product(other, self)
133129
return Product(Constant(other), self)
134130

135131

src/tinygp/solvers/quasisep/solver.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
__all__ = ["QuasisepSolver"]
66

7-
from typing import Any, Optional
7+
from functools import wraps
8+
from typing import Any, Callable, Optional
89

9-
import jax
1010
import jax.numpy as jnp
1111
import numpy as np
1212

@@ -17,6 +17,21 @@
1717
from tinygp.solvers.solver import Solver
1818

1919

20+
def handle_sorting(func: Callable[..., JAXArray]) -> Callable[..., JAXArray]:
21+
@wraps(func)
22+
def wrapped(
23+
self: "QuasisepSolver", y: JAXArray, *args: Any, **kwargs: Any
24+
) -> JAXArray:
25+
if self.inds_to_sorted is not None:
26+
y = y[self.inds_to_sorted]
27+
r = func(self, y, *args, **kwargs)
28+
if self.sorted_to_inds is not None:
29+
return r[self.sorted_to_inds]
30+
return r
31+
32+
return wrapped
33+
34+
2035
@dataclass
2136
class QuasisepSolver(Solver):
2237
"""A scalable solver that uses quasiseparable matrices
@@ -32,6 +47,8 @@ class QuasisepSolver(Solver):
3247
X: JAXArray
3348
matrix: SymmQSM
3449
factor: LowerTriQSM
50+
inds_to_sorted: Optional[JAXArray]
51+
sorted_to_inds: Optional[JAXArray]
3552

3653
@classmethod
3754
def init(
@@ -41,6 +58,7 @@ def init(
4158
noise: Noise,
4259
*,
4360
covariance: Optional[Any] = None,
61+
sort: bool = True,
4462
) -> "QuasisepSolver":
4563
"""Build a :class:`QuasisepSolver` for a given kernel and coordinates
4664
@@ -55,27 +73,55 @@ def init(
5573
"""
5674
from tinygp.kernels.quasisep import Quasisep
5775

76+
inds_to_sorted = None
77+
sorted_to_inds = None
5878
if covariance is None:
5979
assert isinstance(kernel, Quasisep)
80+
81+
if sort:
82+
inds_to_sorted = jnp.argsort(kernel.coord_to_sortable(X))
83+
sorted_to_inds = (
84+
jnp.empty_like(inds_to_sorted)
85+
.at[inds_to_sorted]
86+
.set(jnp.arange(len(inds_to_sorted)))
87+
)
88+
X = X[inds_to_sorted]
89+
6090
matrix = kernel.to_symm_qsm(X)
6191
matrix += noise.to_qsm()
92+
6293
else:
6394
assert isinstance(covariance, SymmQSM)
6495
matrix = covariance
96+
6597
factor = matrix.cholesky()
66-
return cls(X=X, matrix=matrix, factor=factor)
98+
return cls(
99+
X=X,
100+
matrix=matrix,
101+
factor=factor,
102+
inds_to_sorted=inds_to_sorted,
103+
sorted_to_inds=sorted_to_inds,
104+
)
67105

68106
def variance(self) -> JAXArray:
69107
return self.matrix.diag.d
108+
if self.sorted_to_inds is None:
109+
return self.matrix.diag.d
110+
return self.matrix.diag.d[self.sorted_to_inds]
70111

71112
def covariance(self) -> JAXArray:
72-
return self.matrix.to_dense()
113+
cov = self.matrix.to_dense()
114+
return cov
115+
if self.sorted_to_inds is None:
116+
return cov
117+
return cov[self.sorted_to_inds[:, None], self.sorted_to_inds[None, :]]
73118

74119
def normalization(self) -> JAXArray:
75120
return jnp.sum(jnp.log(self.factor.diag.d)) + 0.5 * self.factor.shape[
76121
0
77122
] * np.log(2 * np.pi)
78123

124+
@handle_sorting
79125
def solve_triangular(
80126
self, y: JAXArray, *, transpose: bool = False
81127
) -> JAXArray:
@@ -84,6 +130,7 @@ def solve_triangular(
84130
else:
85131
return self.factor.solve(y)
86132

133+
@handle_sorting
87134
def dot_triangular(self, y: JAXArray) -> JAXArray:
88135
return self.factor @ y
89136

tests/test_solvers/test_quasisep/test_solver.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,37 @@ def test_celerite(data):
125125
calc = gp.log_probability(y)
126126

127127
np.testing.assert_allclose(calc, expected)
128+
129+
130+
def test_unsorted(data):
131+
random = np.random.default_rng(0)
132+
inds = random.permutation(len(data[0]))
133+
inds_t = random.permutation(len(data[2]))
134+
x_ = data[0][inds]
135+
y_ = data[1][inds]
136+
t_ = data[2][inds_t]
137+
138+
kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
139+
gp = GaussianProcess(kernel, data[0], diag=0.1)
140+
gp_ = GaussianProcess(kernel, x_, diag=0.1)
141+
assert isinstance(gp_.solver, QuasisepSolver)
142+
143+
assert np.isfinite(gp_.log_probability(y_))
144+
np.testing.assert_allclose(
145+
gp_.log_probability(y_), gp.log_probability(data[1])
146+
)
147+
148+
np.testing.assert_allclose(
149+
gp_.solver.solve_triangular(y_),
150+
gp.solver.solve_triangular(data[1])[inds],
151+
)
152+
153+
cond = gp.condition(data[1]).gp
154+
cond_ = gp_.condition(y_).gp
155+
np.testing.assert_allclose(cond_.loc, cond.loc[inds])
156+
np.testing.assert_allclose(cond_.variance, cond.variance[inds])
157+
158+
cond = gp.condition(data[1], data[2]).gp
159+
cond_ = gp_.condition(y_, t_).gp
160+
np.testing.assert_allclose(cond_.loc, cond.loc[inds_t])
161+
np.testing.assert_allclose(cond_.variance, cond.variance[inds_t])

0 commit comments

Comments
 (0)