2727]
2828
2929from abc import ABCMeta , abstractmethod
30- from typing import Optional , Union
30+ from typing import Any , Optional , Union
3131
3232import jax
3333import jax .numpy as jnp
@@ -151,7 +151,7 @@ def __add__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
151151 )
152152 return Sum (self , other )
153153
154- def __radd__ (self , other : Union [ "Kernel" , JAXArray ] ) -> "Kernel" :
154+ def __radd__ (self , other : Any ) -> "Kernel" :
155155 # We'll hit this first branch when using the `sum` function
156156 if other == 0 :
157157 return self
@@ -171,7 +171,7 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
171171 )
172172 return Scale (kernel = self , scale = other )
173173
174- def __rmul__ (self , other : Union [ "Kernel" , JAXArray ] ) -> "Kernel" :
174+ def __rmul__ (self , other : Any ) -> "Kernel" :
175175 if isinstance (other , Quasisep ):
176176 return Product (other , self )
177177 if isinstance (other , Kernel ) or jnp .ndim (other ) != 0 :
@@ -204,6 +204,9 @@ class Wrapper(Quasisep, metaclass=ABCMeta):
204204
205205 kernel : Quasisep
206206
207+ def coord_to_sortable (self , X : JAXArray ) -> JAXArray :
208+ return self .kernel .coord_to_sortable (X )
209+
207210 def design_matrix (self ) -> JAXArray :
208211 return self .kernel .design_matrix ()
209212
@@ -226,6 +229,10 @@ class Sum(Quasisep):
226229 kernel1 : Quasisep
227230 kernel2 : Quasisep
228231
232+ def coord_to_sortable (self , X : JAXArray ) -> JAXArray :
233+ """We assume that both kernels use the same coordinates"""
234+ return self .kernel1 .coord_to_sortable (X )
235+
229236 def design_matrix (self ) -> JAXArray :
230237 return jsp .linalg .block_diag (
231238 self .kernel1 .design_matrix (), self .kernel2 .design_matrix ()
@@ -259,6 +266,10 @@ class Product(Quasisep):
259266 kernel1 : Quasisep
260267 kernel2 : Quasisep
261268
269+ def coord_to_sortable (self , X : JAXArray ) -> JAXArray :
270+ """We assume that both kernels use the same coordinates"""
271+ return self .kernel1 .coord_to_sortable (X )
272+
262273 def design_matrix (self ) -> JAXArray :
263274 F1 = self .kernel1 .design_matrix ()
264275 F2 = self .kernel2 .design_matrix ()
@@ -699,14 +710,14 @@ def init(
699710 params = jnp .linalg .solve (
700711 params , 0.5 * sigma ** 2 * jnp .eye (p , 1 , k = - p + 1 )
701712 )[:, 0 ]
702- stn = []
713+ stn_ = []
703714 for j in range (p ):
704- stn .append ([jnp .zeros (()) for _ in range (p )])
715+ stn_ .append ([jnp .zeros (()) for _ in range (p )])
705716 for n , k in enumerate (range (j - 2 , - 1 , - 2 )):
706- stn [- 1 ][k ] = (2 * (n % 2 ) - 1 ) * params [j - n - 1 ]
717+ stn_ [- 1 ][k ] = (2 * (n % 2 ) - 1 ) * params [j - n - 1 ]
707718 for n , k in enumerate (range (j , p , 2 )):
708- stn [- 1 ][k ] = (1 - 2 * (n % 2 )) * params [n + j ]
709- stn = jnp .array (list (map (jnp .stack , stn )))
719+ stn_ [- 1 ][k ] = (1 - 2 * (n % 2 )) * params [n + j ]
720+ stn = jnp .array (list (map (jnp .stack , stn_ )))
710721
711722 return cls (
712723 sigma = sigma ,
0 commit comments