11from __future__ import annotations
2- import numpy as np
2+ import jax . numpy as jnp
33from typing import TYPE_CHECKING
44
55if TYPE_CHECKING :
66 from autoarray .inversion .linear_obj .linear_obj import LinearObj
77
88from autoarray .inversion .regularization .abstract import AbstractRegularization
99
10- from autoarray import numba_util
1110
12-
13- @numba_util .jit ()
1411def exp_cov_matrix_from (
1512 scale : float ,
16- pixel_points : np .ndarray ,
17- ) -> np .ndarray :
13+ pixel_points : jnp .ndarray , # shape (N, 2)
14+ ) -> jnp .ndarray : # shape (N, N)
1815 """
19- Consutruct the source brightness covariance matrix, which is used to determined the regularization
20- pattern (i.e, how the different source pixels are smoothed).
16+ Construct the source brightness covariance matrix using an exponential kernel:
17+
18+ cov[i,j] = exp(- d_{ij} / scale)
2119
22- The covariance matrix includes one non-linear parameters, the scale coefficient, which is used to determine
23- the typical scale of the regularization pattern.
20+ with a tiny jitter 1e-8 added on the diagonal for numerical stability.
2421
2522 Parameters
2623 ----------
2724 scale
28- The typical scale of the regularization pattern .
25+ The length‐ scale of the exponential kernel .
2926 pixel_points
30- An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane).
31- Something like [[y1,x1], [y2,x2], ...]
27+ Array of shape (N, 2) giving the (y,x) coordinates of each source‐plane pixel.
3228
3329 Returns
3430 -------
35- np .ndarray
36- The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels] .
31+ jnp .ndarray, shape (N, N)
32+ The exponential covariance matrix.
3733 """
34+ # pairwise differences: shape (N, N, 2)
35+ diff = pixel_points [:, None , :] - pixel_points [None , :, :]
3836
39- pixels = len ( pixel_points )
40- covariance_matrix = np . zeros ( shape = ( pixels , pixels ) )
37+ # Euclidean distances: shape (N, N )
38+ d = jnp . linalg . norm ( diff , axis = - 1 )
4139
42- for i in range (pixels ):
43- covariance_matrix [i , i ] += 1e-8
44- for j in range (pixels ):
45- xi = pixel_points [i , 1 ]
46- yi = pixel_points [i , 0 ]
47- xj = pixel_points [j , 1 ]
48- yj = pixel_points [j , 0 ]
49- d_ij = np .sqrt (
50- (xi - xj ) ** 2 + (yi - yj ) ** 2
51- ) # distance between the pixel i and j
40+ # exponential kernel
41+ cov = jnp .exp (- d / scale )
5242
53- covariance_matrix [i , j ] += np .exp (- 1.0 * d_ij / scale )
43+ # add a small jitter on the diagonal
44+ N = pixel_points .shape [0 ]
45+ cov = cov + jnp .eye (N ) * 1e-8
5446
55- return covariance_matrix
47+ return cov
5648
5749
5850class ExponentialKernel (AbstractRegularization ):
@@ -83,7 +75,7 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0):
8375
8476 super ().__init__ ()
8577
86- def regularization_weights_from (self , linear_obj : LinearObj ) -> np .ndarray :
78+ def regularization_weights_from (self , linear_obj : LinearObj ) -> jnp .ndarray :
8779 """
8880 Returns the regularization weights of this regularization scheme.
8981
@@ -102,9 +94,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
10294 -------
10395 The regularization weights.
10496 """
105- return self .coefficient * np .ones (linear_obj .params )
97+ return self .coefficient * jnp .ones (linear_obj .params )
10698
107- def regularization_matrix_from (self , linear_obj : LinearObj ) -> np .ndarray :
99+ def regularization_matrix_from (self , linear_obj : LinearObj ) -> jnp .ndarray :
108100 """
109101 Returns the regularization matrix with shape [pixels, pixels].
110102
@@ -119,7 +111,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
119111 """
120112 covariance_matrix = exp_cov_matrix_from (
121113 scale = self .scale ,
122- pixel_points = np . array ( linear_obj .source_plane_mesh_grid ) ,
114+ pixel_points = linear_obj .source_plane_mesh_grid . array ,
123115 )
124116
125- return self .coefficient * np .linalg .inv (covariance_matrix )
117+ return self .coefficient * jnp .linalg .inv (covariance_matrix )
0 commit comments