Skip to content

Commit a63369d

Browse files
Jammy2211Jammy2211
authored andcommitted
regulsirztion refactor complete and rectangular works
1 parent 150a360 commit a63369d

14 files changed

Lines changed: 90 additions & 68 deletions

File tree

autoarray/inversion/pixelization/mesh/mesh_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def rectangular_corner_neighbors(
110110

111111
return neighbors, neighbors_sizes
112112

113+
113114
def rectangular_top_edge_neighbors(
114115
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
115116
) -> Tuple[np.ndarray, np.ndarray]:
@@ -145,6 +146,7 @@ def rectangular_top_edge_neighbors(
145146

146147
return neighbors, neighbors_sizes
147148

149+
148150
def rectangular_left_edge_neighbors(
149151
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
150152
) -> Tuple[np.ndarray, np.ndarray]:
@@ -180,6 +182,7 @@ def rectangular_left_edge_neighbors(
180182

181183
return neighbors, neighbors_sizes
182184

185+
183186
def rectangular_right_edge_neighbors(
184187
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
185188
) -> Tuple[np.ndarray, np.ndarray]:
@@ -215,6 +218,7 @@ def rectangular_right_edge_neighbors(
215218

216219
return neighbors, neighbors_sizes
217220

221+
218222
def rectangular_bottom_edge_neighbors(
219223
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
220224
) -> Tuple[np.ndarray, np.ndarray]:
@@ -288,14 +292,15 @@ def rectangular_central_neighbors(
288292

289293
# Compute neighbor indices
290294
neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up
291-
neighbors[pixel_indices, 1] = pixel_indices - 1 # Left
292-
neighbors[pixel_indices, 2] = pixel_indices + 1 # Right
295+
neighbors[pixel_indices, 1] = pixel_indices - 1 # Left
296+
neighbors[pixel_indices, 2] = pixel_indices + 1 # Right
293297
neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down
294298

295299
neighbors_sizes[pixel_indices] = 4
296300

297301
return neighbors, neighbors_sizes
298302

303+
299304
def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]:
300305
"""
301306
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization,

autoarray/inversion/regularization/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,3 @@
1010
from .gaussian_kernel import GaussianKernel
1111
from .exponential_kernel import ExponentialKernel
1212
from .matern_kernel import MaternKernel
13-

autoarray/inversion/regularization/adaptive_brightness.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,19 @@ def weighted_regularization_matrix_from(
8080
coefficient of every source pixel is different.
8181
"""
8282
S, P = neighbors.shape
83-
reg_w = regularization_weights ** 2
83+
reg_w = regularization_weights**2
8484

8585
# 1) Flatten the (i→j) neighbor pairs
86-
I = jnp.repeat(jnp.arange(S), P) # (S*P,)
87-
J = neighbors.reshape(-1) # (S*P,)
86+
I = jnp.repeat(jnp.arange(S), P) # (S*P,)
87+
J = neighbors.reshape(-1) # (S*P,)
8888

8989
# 2) Remap “no neighbor” entries to an extra slot S, whose weight=0
9090
OUT = S
9191
J = jnp.where(J < 0, OUT, J)
9292

9393
# 3) Build an extended weight vector with a zero at index S
9494
reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0)
95-
w_ij = reg_w_ext[J] # (S*P,)
95+
w_ij = reg_w_ext[J] # (S*P,)
9696

9797
# 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely
9898
mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype)
@@ -102,10 +102,9 @@ def weighted_regularization_matrix_from(
102102
# - sum_j reg_w[j] into diag[i]
103103
# - sum contributions reg_w[j] into diag[j]
104104
# (diagonal at OUT=S picks up zeros only)
105-
diag_updates_i = jnp.concatenate([
106-
jnp.full((S,), 1e-8),
107-
jnp.zeros((1,)) # out‐of‐bounds slot stays zero
108-
], axis=0)
105+
diag_updates_i = jnp.concatenate(
106+
[jnp.full((S,), 1e-8), jnp.zeros((1,))], axis=0 # out‐of‐bounds slot stays zero
107+
)
109108
mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i)
110109
mat = mat.at[I, I].add(w_ij)
111110
mat = mat.at[J, J].add(w_ij)
@@ -117,6 +116,7 @@ def weighted_regularization_matrix_from(
117116
# 7) Drop the extra row/column S and return the S×S result
118117
return mat[:S, :S]
119118

119+
120120
class AdaptiveBrightness(AbstractRegularization):
121121
def __init__(
122122
self,

autoarray/inversion/regularization/brightness_zeroth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def brightness_zeroth_regularization_weights_from(
4242
"""
4343
return coefficient * (1.0 - pixel_signals)
4444

45+
4546
def brightness_zeroth_regularization_matrix_from(
4647
regularization_weights: jnp.ndarray,
4748
) -> jnp.ndarray:
@@ -63,7 +64,6 @@ def brightness_zeroth_regularization_matrix_from(
6364
return jnp.diag(regularization_weight_squared)
6465

6566

66-
6767
class BrightnessZeroth(AbstractRegularization):
6868
def __init__(
6969
self,

autoarray/inversion/regularization/constant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ class in the module `autoarray.inversion.regularization`.
5252
# This ensures that JAX can efficiently drop these entries during matrix updates.
5353
neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors)
5454
return (
55-
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors]
55+
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[
56+
I_IDX, neighbors
57+
]
5658
# unique indices should be guranteed by neighbors-spec
5759
.add(-regularization_coefficient, mode="drop", unique_indices=True)
5860
)

autoarray/inversion/regularization/constant_zeroth.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ class in the module ``autoarray.inversion.regularization``.
5050
# This ensures that JAX can efficiently drop these entries during matrix updates.
5151
neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors)
5252
const = (
53-
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors]
53+
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[
54+
I_IDX, neighbors
55+
]
5456
# unique indices should be guranteed by neighbors-spec
5557
.add(-regularization_coefficient, mode="drop", unique_indices=True)
5658
)
5759

58-
reg_coeff = coefficient_zeroth ** 2.0
60+
reg_coeff = coefficient_zeroth**2.0
5961
# Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T
6062
zeroth = jnp.eye(P) * reg_coeff
6163

autoarray/inversion/regularization/exponential_kernel.py

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,50 @@
11
from __future__ import annotations
2-
import numpy as np
2+
import jax.numpy as jnp
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
66
from autoarray.inversion.linear_obj.linear_obj import LinearObj
77

88
from autoarray.inversion.regularization.abstract import AbstractRegularization
99

10-
from autoarray import numba_util
1110

12-
13-
@numba_util.jit()
1411
def exp_cov_matrix_from(
1512
scale: float,
16-
pixel_points: np.ndarray,
17-
) -> np.ndarray:
13+
pixel_points: jnp.ndarray, # shape (N, 2)
14+
) -> jnp.ndarray: # shape (N, N)
1815
"""
19-
Consutruct the source brightness covariance matrix, which is used to determined the regularization
20-
pattern (i.e, how the different source pixels are smoothed).
16+
Construct the source brightness covariance matrix using an exponential kernel:
17+
18+
cov[i,j] = exp(- d_{ij} / scale)
2119
22-
The covariance matrix includes one non-linear parameters, the scale coefficient, which is used to determine
23-
the typical scale of the regularization pattern.
20+
with a tiny jitter 1e-8 added on the diagonal for numerical stability.
2421
2522
Parameters
2623
----------
2724
scale
28-
The typical scale of the regularization pattern .
25+
The length‐scale of the exponential kernel.
2926
pixel_points
30-
An 2d array with shape [N_source_pixels, 2], which save the source pixelization coordinates (on source plane).
31-
Something like [[y1,x1], [y2,x2], ...]
27+
Array of shape (N, 2) giving the (y,x) coordinates of each source‐plane pixel.
3228
3329
Returns
3430
-------
35-
np.ndarray
36-
The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels].
31+
jnp.ndarray, shape (N, N)
32+
The exponential covariance matrix.
3733
"""
34+
# pairwise differences: shape (N, N, 2)
35+
diff = pixel_points[:, None, :] - pixel_points[None, :, :]
3836

39-
pixels = len(pixel_points)
40-
covariance_matrix = np.zeros(shape=(pixels, pixels))
37+
# Euclidean distances: shape (N, N)
38+
d = jnp.linalg.norm(diff, axis=-1)
4139

42-
for i in range(pixels):
43-
covariance_matrix[i, i] += 1e-8
44-
for j in range(pixels):
45-
xi = pixel_points[i, 1]
46-
yi = pixel_points[i, 0]
47-
xj = pixel_points[j, 1]
48-
yj = pixel_points[j, 0]
49-
d_ij = np.sqrt(
50-
(xi - xj) ** 2 + (yi - yj) ** 2
51-
) # distance between the pixel i and j
40+
# exponential kernel
41+
cov = jnp.exp(-d / scale)
5242

53-
covariance_matrix[i, j] += np.exp(-1.0 * d_ij / scale)
43+
# add a small jitter on the diagonal
44+
N = pixel_points.shape[0]
45+
cov = cov + jnp.eye(N) * 1e-8
5446

55-
return covariance_matrix
47+
return cov
5648

5749

5850
class ExponentialKernel(AbstractRegularization):
@@ -83,7 +75,7 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0):
8375

8476
super().__init__()
8577

86-
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
78+
def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray:
8779
"""
8880
Returns the regularization weights of this regularization scheme.
8981
@@ -102,9 +94,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
10294
-------
10395
The regularization weights.
10496
"""
105-
return self.coefficient * np.ones(linear_obj.params)
97+
return self.coefficient * jnp.ones(linear_obj.params)
10698

107-
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
99+
def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray:
108100
"""
109101
Returns the regularization matrix with shape [pixels, pixels].
110102
@@ -119,7 +111,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
119111
"""
120112
covariance_matrix = exp_cov_matrix_from(
121113
scale=self.scale,
122-
pixel_points=np.array(linear_obj.source_plane_mesh_grid),
114+
pixel_points=linear_obj.source_plane_mesh_grid.array,
123115
)
124116

125-
return self.coefficient * np.linalg.inv(covariance_matrix)
117+
return self.coefficient * jnp.linalg.inv(covariance_matrix)

autoarray/inversion/regularization/gaussian_kernel.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from autoarray.inversion.regularization.abstract import AbstractRegularization
1010

11+
1112
def gauss_cov_matrix_from(
1213
scale: float,
1314
pixel_points: jnp.ndarray, # shape (N, 2)
@@ -34,17 +35,17 @@ def gauss_cov_matrix_from(
3435
The Gaussian covariance matrix.
3536
"""
3637
# Ensure array:
37-
pts = jnp.asarray(pixel_points) # (N, 2)
38+
pts = jnp.asarray(pixel_points) # (N, 2)
3839
# Compute squared distances: ||p_i - p_j||^2
39-
diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2)
40-
d2 = jnp.sum(diffs**2, axis=-1) # (N, N)
40+
diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2)
41+
d2 = jnp.sum(diffs**2, axis=-1) # (N, N)
4142

4243
# Gaussian kernel
43-
cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N)
44+
cov = jnp.exp(-d2 / (2.0 * scale**2)) # (N, N)
4445

4546
# Add tiny jitter on the diagonal
46-
N = pts.shape[0]
47-
cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8
47+
N = pts.shape[0]
48+
cov = cov + jnp.eye(N, dtype=cov.dtype) * 1e-8
4849

4950
return cov
5051

@@ -111,8 +112,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
111112
The regularization matrix.
112113
"""
113114
covariance_matrix = gauss_cov_matrix_from(
114-
scale=self.scale,
115-
pixel_points=linear_obj.source_plane_mesh_grid.array
115+
scale=self.scale, pixel_points=linear_obj.source_plane_mesh_grid.array
116116
)
117117

118118
return self.coefficient * jnp.linalg.inv(covariance_matrix)

autoarray/inversion/regularization/regularization_util.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,24 @@
33

44
from autoarray import exc
55

6-
from autoarray.inversion.regularization.adaptive_brightness import adaptive_regularization_weights_from
7-
from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_matrix_from
8-
from autoarray.inversion.regularization.brightness_zeroth import brightness_zeroth_regularization_weights_from
9-
from autoarray.inversion.regularization.constant import constant_regularization_matrix_from
6+
from autoarray.inversion.regularization.adaptive_brightness import (
7+
adaptive_regularization_weights_from,
8+
)
9+
from autoarray.inversion.regularization.adaptive_brightness import (
10+
weighted_regularization_matrix_from,
11+
)
12+
from autoarray.inversion.regularization.brightness_zeroth import (
13+
brightness_zeroth_regularization_matrix_from,
14+
)
15+
from autoarray.inversion.regularization.brightness_zeroth import (
16+
brightness_zeroth_regularization_weights_from,
17+
)
18+
from autoarray.inversion.regularization.constant import (
19+
constant_regularization_matrix_from,
20+
)
21+
from autoarray.inversion.regularization.constant_zeroth import (
22+
constant_zeroth_regularization_matrix_from,
23+
)
1024
from autoarray.inversion.regularization.exponential_kernel import exp_cov_matrix_from
1125
from autoarray.inversion.regularization.gaussian_kernel import gauss_cov_matrix_from
1226
from autoarray.inversion.regularization.matern_kernel import matern_kernel

autoarray/inversion/regularization/zeroth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class in the module `autoarray.inversion.regularization`.
3030
coefficient of every source pixel is the same.
3131
"""
3232

33-
reg_coeff = coefficient ** 2.0
33+
reg_coeff = coefficient**2.0
3434

3535
# Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T
3636

0 commit comments

Comments
 (0)