Skip to content

Commit 8a6951a

Browse files
Jammy2211Jammy2211
authored andcommitted
gaussian kernel converted successfully
1 parent 1d6043b commit 8a6951a

6 files changed

Lines changed: 65 additions & 51 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,9 @@ def log_det_regularization_matrix_term(self) -> float:
663663
float
664664
The log determinant of the regularization matrix.
665665
"""
666+
if not self.has(cls=AbstractRegularization):
667+
return 0.0
668+
666669
try:
667670
return 2.0 * np.sum(
668671
jnp.log(

autoarray/inversion/regularization/constant_zeroth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,4 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
5555
coefficient=self.coefficient_neighbor,
5656
coefficient_zeroth=self.coefficient_zeroth,
5757
neighbors=linear_obj.neighbors,
58-
neighbors_sizes=linear_obj.neighbors.sizes,
5958
)

autoarray/inversion/regularization/gaussian_kernel.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import jax.numpy as jnp
23
import numpy as np
34
from typing import TYPE_CHECKING
45

@@ -7,52 +8,45 @@
78

89
from autoarray.inversion.regularization.abstract import AbstractRegularization
910

10-
from autoarray import numba_util
11-
12-
13-
@numba_util.jit()
1411
def gauss_cov_matrix_from(
1512
scale: float,
16-
pixel_points: np.ndarray,
17-
) -> np.ndarray:
13+
pixel_points: jnp.ndarray, # shape (N, 2)
14+
) -> jnp.ndarray:
1815
"""
19-
Consutruct the source brightness covariance matrix, which is used to determined the regularization
20-
pattern (i.e, how the different source pixels are smoothed).
16+
Construct the source‐pixel Gaussian covariance matrix for regularization.
17+
18+
For N source‐pixels at coordinates (y_i, x_i), we define
2119
22-
the covariance matrix includes one non-linear parameters, the scale coefficient, which is used to
23-
determine the typical scale of the regularization pattern.
20+
C_ij = exp( -||p_i - p_j||^2 / (2 scale^2) )
21+
22+
plus a tiny diagonal “jitter” (1e-8) to ensure numerical stability.
2423
2524
Parameters
2625
----------
2726
scale
28-
the typical scale of the regularization pattern .
27+
The characteristic length scale of the Gaussian kernel.
2928
pixel_points
30-
An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane).
31-
Something like [[y1,x1], [y2,x2], ...]
29+
Array of shape (N, 2), giving the (y, x) coordinates of each source pixel.
3230
3331
Returns
3432
-------
35-
np.ndarray
36-
The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels].
33+
cov : jnp.ndarray, shape (N, N)
34+
The Gaussian covariance matrix.
3735
"""
36+
# Ensure array:
37+
pts = jnp.asarray(pixel_points) # (N, 2)
38+
# Compute squared distances: ||p_i - p_j||^2
39+
diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2)
40+
d2 = jnp.sum(diffs**2, axis=-1) # (N, N)
3841

39-
pixels = len(pixel_points)
40-
covariance_matrix = np.zeros(shape=(pixels, pixels))
41-
42-
for i in range(pixels):
43-
covariance_matrix[i, i] += 1e-8
44-
for j in range(pixels):
45-
xi = pixel_points[i, 1]
46-
yi = pixel_points[i, 0]
47-
xj = pixel_points[j, 1]
48-
yj = pixel_points[j, 0]
49-
d_ij = np.sqrt(
50-
(xi - xj) ** 2 + (yi - yj) ** 2
51-
) # distance between the pixel i and j
42+
# Gaussian kernel
43+
cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N)
5244

53-
covariance_matrix[i, j] += np.exp(-1.0 * d_ij**2 / (2 * scale**2))
45+
# Add tiny jitter on the diagonal
46+
N = pts.shape[0]
47+
cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8
5448

55-
return covariance_matrix
49+
return cov
5650

5751

5852
class GaussianKernel(AbstractRegularization):
@@ -117,7 +111,8 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
117111
The regularization matrix.
118112
"""
119113
covariance_matrix = gauss_cov_matrix_from(
120-
scale=self.scale, pixel_points=np.array(linear_obj.source_plane_mesh_grid)
114+
scale=self.scale,
115+
pixel_points=linear_obj.source_plane_mesh_grid.array
121116
)
122117

123-
return self.coefficient * np.linalg.inv(covariance_matrix)
118+
return self.coefficient * jnp.linalg.inv(covariance_matrix)

autoarray/inversion/regularization/regularization_util.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,10 @@ class in the module `autoarray.inversion.regularization`.
8282
)
8383

8484

85-
86-
@numba_util.jit()
8785
def constant_zeroth_regularization_matrix_from(
8886
coefficient: float,
8987
coefficient_zeroth: float,
9088
neighbors: np.ndarray,
91-
neighbors_sizes: np.ndarray,
9289
) -> np.ndarray:
9390
"""
9491
From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme.
@@ -113,24 +110,34 @@ class in the module ``autoarray.inversion.regularization``.
113110
The regularization matrix computed using Regularization where the effective regularization
114111
coefficient of every source pixel is the same.
115112
"""
113+
S, P = neighbors.shape
114+
reg1 = coefficient**2
115+
reg0 = coefficient_zeroth**2
116+
117+
# 1) Flatten (i,j) neighbor‐pairs
118+
I = jnp.repeat(jnp.arange(S), P) # (S*P,)
119+
J = neighbors.reshape(-1) # (S*P,)
116120

117-
pixels = len(neighbors)
121+
# 2) Remap “no neighbor” = -1 → OUT = S
122+
OUT = S
123+
J = jnp.where(J < 0, OUT, J)
118124

119-
regularization_matrix = np.zeros(shape=(pixels, pixels))
125+
# 3) Start on an (S+1)x(S+1) zero canvas
126+
M = jnp.zeros((S+1, S+1), dtype=jnp.float32)
120127

121-
regularization_coefficient = coefficient**2.0
122-
regularization_coefficient_zeroth = coefficient_zeroth**2.0
128+
# 4) Diagonal baseline: 1e-8 + reg0 for i in [0..S-1]
129+
diag_base = jnp.concatenate([jnp.full((S,), 1e-8 + reg0), jnp.zeros((1,))])
130+
M = M.at[jnp.diag_indices(S+1)].add(diag_base)
123131

124-
for i in range(pixels):
125-
regularization_matrix[i, i] += 1e-8
126-
regularization_matrix[i, i] += regularization_coefficient_zeroth
127-
for j in range(neighbors_sizes[i]):
128-
neighbor_index = neighbors[i, j]
129-
regularization_matrix[i, i] += regularization_coefficient
130-
regularization_matrix[i, neighbor_index] -= regularization_coefficient
132+
# 5) Scatter the first-order reg1 into diag[i] for each neighbor (i→j):
133+
# M[i,i] += reg1
134+
M = M.at[I, I].add(reg1)
131135

132-
return regularization_matrix
136+
# 6) Scatter the off-diagonals: M[i,j] -= reg1
137+
M = M.at[I, J].add(-reg1)
133138

139+
# 7) Return only the top‐left S×S block
140+
return M[:S, :S]
134141

135142
def adaptive_regularization_weights_from(
136143
inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray

autoarray/preloads.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(
1414
self,
1515
mapper_indices: np.ndarray = None,
1616
source_pixel_zeroed_indices: np.ndarray = None,
17+
linear_light_profile_blurred_mapping_matrix = None,
1718
):
1819
"""
1920
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
@@ -37,11 +38,17 @@ def __init__(
3738
Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to
3839
outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping
3940
separate the lens light from the pixelized source model.
41+
linear_light_profile_blurred_mapping_matrix
42+
The evaluated images of the linear light profiles that make up the blurred mapping matrix component of the
43+
inversion, with the other component being the pixelization's pixels. These are fixed when the lens light
44+
is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but
45+
the intensity values will still be solved for during the inversion.
4046
"""
4147

4248
self.mapper_indices = None
4349
self.source_pixel_zeroed_indices = None
4450
self.source_pixel_zeroed_indices_to_keep = None
51+
self.linear_light_profile_blurred_mapping_matrix = None
4552

4653
if mapper_indices is not None:
4754

@@ -58,3 +65,9 @@ def __init__(
5865

5966
# Get the indices where values_to_solve is True
6067
self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0]
68+
69+
if linear_light_profile_blurred_mapping_matrix is not None:
70+
71+
self.linear_light_profile_blurred_mapping_matrix = jnp.array(
72+
linear_light_profile_blurred_mapping_matrix
73+
)

test_autoarray/inversion/regularizations/test_regularization_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,11 @@ def test__constant_regularization_matrix_from():
147147
def test__constant_zeroth_regularization_matrix_from():
148148
neighbors = np.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]])
149149

150-
neighbors_sizes = np.array([2, 1, 1])
151-
152150
regularization_matrix = (
153151
aa.util.regularization.constant_zeroth_regularization_matrix_from(
154152
coefficient=2.0,
155153
coefficient_zeroth=0.5,
156154
neighbors=neighbors,
157-
neighbors_sizes=neighbors_sizes,
158155
)
159156
)
160157

0 commit comments

Comments
 (0)