|
1 | 1 | from __future__ import annotations |
| 2 | +import jax.numpy as jnp |
2 | 3 | import numpy as np |
3 | 4 | from typing import TYPE_CHECKING |
4 | 5 |
|
|
7 | 8 |
|
8 | 9 | from autoarray.inversion.regularization.abstract import AbstractRegularization |
9 | 10 |
|
10 | | -from autoarray import numba_util |
11 | | - |
12 | | - |
13 | | -@numba_util.jit() |
14 | 11 | def gauss_cov_matrix_from( |
15 | 12 | scale: float, |
16 | | - pixel_points: np.ndarray, |
17 | | -) -> np.ndarray: |
| 13 | + pixel_points: jnp.ndarray, # shape (N, 2) |
| 14 | +) -> jnp.ndarray: |
18 | 15 | """ |
19 | | - Consutruct the source brightness covariance matrix, which is used to determined the regularization |
20 | | - pattern (i.e, how the different source pixels are smoothed). |
| 16 | + Construct the source‐pixel Gaussian covariance matrix for regularization. |
| 17 | +
|
| 18 | + For N source‐pixels at coordinates (y_i, x_i), we define |
21 | 19 |
|
22 | | - the covariance matrix includes one non-linear parameters, the scale coefficient, which is used to |
23 | | - determine the typical scale of the regularization pattern. |
| 20 | + C_ij = exp( -||p_i - p_j||^2 / (2 scale^2) ) |
| 21 | +
|
| 22 | + plus a tiny diagonal “jitter” (1e-8) to ensure numerical stability. |
24 | 23 |
|
25 | 24 | Parameters |
26 | 25 | ---------- |
27 | 26 | scale |
28 | | - the typical scale of the regularization pattern . |
| 27 | + The characteristic length scale of the Gaussian kernel. |
29 | 28 | pixel_points |
30 | | - An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane). |
31 | | - Something like [[y1,x1], [y2,x2], ...] |
| 29 | + Array of shape (N, 2), giving the (y, x) coordinates of each source pixel. |
32 | 30 |
|
33 | 31 | Returns |
34 | 32 | ------- |
35 | | - np.ndarray |
36 | | - The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels]. |
| 33 | + cov : jnp.ndarray, shape (N, N) |
| 34 | + The Gaussian covariance matrix. |
37 | 35 | """ |
| 36 | + # Ensure array: |
| 37 | + pts = jnp.asarray(pixel_points) # (N, 2) |
| 38 | + # Compute squared distances: ||p_i - p_j||^2 |
| 39 | + diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2) |
| 40 | + d2 = jnp.sum(diffs**2, axis=-1) # (N, N) |
38 | 41 |
|
39 | | - pixels = len(pixel_points) |
40 | | - covariance_matrix = np.zeros(shape=(pixels, pixels)) |
41 | | - |
42 | | - for i in range(pixels): |
43 | | - covariance_matrix[i, i] += 1e-8 |
44 | | - for j in range(pixels): |
45 | | - xi = pixel_points[i, 1] |
46 | | - yi = pixel_points[i, 0] |
47 | | - xj = pixel_points[j, 1] |
48 | | - yj = pixel_points[j, 0] |
49 | | - d_ij = np.sqrt( |
50 | | - (xi - xj) ** 2 + (yi - yj) ** 2 |
51 | | - ) # distance between the pixel i and j |
| 42 | + # Gaussian kernel |
| 43 | + cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N) |
52 | 44 |
|
53 | | - covariance_matrix[i, j] += np.exp(-1.0 * d_ij**2 / (2 * scale**2)) |
| 45 | + # Add tiny jitter on the diagonal |
| 46 | + N = pts.shape[0] |
| 47 | + cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8 |
54 | 48 |
|
55 | | - return covariance_matrix |
| 49 | + return cov |
56 | 50 |
|
57 | 51 |
|
58 | 52 | class GaussianKernel(AbstractRegularization): |
@@ -117,7 +111,8 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray: |
117 | 111 | The regularization matrix. |
118 | 112 | """ |
119 | 113 | covariance_matrix = gauss_cov_matrix_from( |
120 | | - scale=self.scale, pixel_points=np.array(linear_obj.source_plane_mesh_grid) |
| 114 | + scale=self.scale, |
| 115 | + pixel_points=linear_obj.source_plane_mesh_grid.array |
121 | 116 | ) |
122 | 117 |
|
123 | | - return self.coefficient * np.linalg.inv(covariance_matrix) |
| 118 | + return self.coefficient * jnp.linalg.inv(covariance_matrix) |
0 commit comments