Skip to content

Commit 1d6043b

Browse files
Jammy2211Jammy2211
authored andcommitted
regularization util JAX conversons, seem to work
1 parent 5d94e9a commit 1d6043b

10 files changed

Lines changed: 166 additions & 219 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import jax.numpy as jnp
3+
from jax.scipy.linalg import block_diag
34
import numpy as np
45

56
from typing import Dict, List, Optional, Type, Union
@@ -334,8 +335,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
334335
If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion
335336
are regularized so high their value is forced to zero.
336337
"""
337-
from scipy.linalg import block_diag
338-
339338
return block_diag(
340339
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
341340
)
@@ -664,30 +663,14 @@ def log_det_regularization_matrix_term(self) -> float:
664663
float
665664
The log determinant of the regularization matrix.
666665
"""
667-
from scipy.sparse import csc_matrix
668-
from scipy.sparse.linalg import splu
669-
670-
if not self.has(cls=AbstractRegularization):
671-
return 0.0
672-
673666
try:
674-
lu = splu(csc_matrix(self.regularization_matrix_reduced))
675-
diagL = lu.L.diagonal()
676-
diagU = lu.U.diagonal()
677-
diagL = diagL.astype(np.complex128)
678-
diagU = diagU.astype(np.complex128)
679-
680-
return np.real(np.log(diagL).sum() + np.log(diagU).sum())
681-
682-
except RuntimeError:
683-
try:
684-
return 2.0 * np.sum(
685-
np.log(
686-
np.diag(np.linalg.cholesky(self.regularization_matrix_reduced))
687-
)
667+
return 2.0 * np.sum(
668+
jnp.log(
669+
jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced))
688670
)
689-
except np.linalg.LinAlgError as e:
690-
raise exc.InversionException() from e
671+
)
672+
except np.linalg.LinAlgError as e:
673+
raise exc.InversionException() from e
691674

692675
@property
693676
def reconstruction_noise_map_with_covariance(self) -> np.ndarray:

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray:
288288
pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index,
289289
pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index,
290290
slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim,
291-
adapt_data=np.array(self.adapt_data),
291+
adapt_data=self.adapt_data.array,
292292
)
293293

294294
def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]:

autoarray/inversion/pixelization/mesh/mesh_util.py

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from autoarray import numba_util
66

77

8-
@numba_util.jit()
98
def rectangular_neighbors_from(
109
shape_native: Tuple[int, int],
1110
) -> Tuple[np.ndarray, np.ndarray]:
@@ -68,7 +67,6 @@ def rectangular_neighbors_from(
6867
return neighbors, neighbors_sizes
6968

7069

71-
@numba_util.jit()
7270
def rectangular_corner_neighbors(
7371
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
7472
) -> Tuple[np.ndarray, np.ndarray]:
@@ -112,8 +110,6 @@ def rectangular_corner_neighbors(
112110

113111
return neighbors, neighbors_sizes
114112

115-
116-
@numba_util.jit()
117113
def rectangular_top_edge_neighbors(
118114
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
119115
) -> Tuple[np.ndarray, np.ndarray]:
@@ -136,17 +132,19 @@ def rectangular_top_edge_neighbors(
136132
-------
137133
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
138134
"""
139-
for pix in range(1, shape_native[1] - 1):
140-
pixel_index = pix
141-
neighbors[pixel_index, 0:3] = np.array(
142-
[pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]]
143-
)
144-
neighbors_sizes[pixel_index] = 3
135+
"""
136+
Vectorized version of the top edge neighbor update using NumPy arithmetic.
137+
"""
138+
# Pixels along the top edge, excluding corners
139+
top_edge_pixels = np.arange(1, shape_native[1] - 1)
145140

146-
return neighbors, neighbors_sizes
141+
neighbors[top_edge_pixels, 0] = top_edge_pixels - 1
142+
neighbors[top_edge_pixels, 1] = top_edge_pixels + 1
143+
neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1]
144+
neighbors_sizes[top_edge_pixels] = 3
147145

146+
return neighbors, neighbors_sizes
148147

149-
@numba_util.jit()
150148
def rectangular_left_edge_neighbors(
151149
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
152150
) -> Tuple[np.ndarray, np.ndarray]:
@@ -169,21 +167,19 @@ def rectangular_left_edge_neighbors(
169167
-------
170168
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
171169
"""
172-
for pix in range(1, shape_native[0] - 1):
173-
pixel_index = pix * shape_native[1]
174-
neighbors[pixel_index, 0:3] = np.array(
175-
[
176-
pixel_index - shape_native[1],
177-
pixel_index + 1,
178-
pixel_index + shape_native[1],
179-
]
180-
)
181-
neighbors_sizes[pixel_index] = 3
170+
# Row indices (excluding top and bottom corners)
171+
rows = np.arange(1, shape_native[0] - 1)
182172

183-
return neighbors, neighbors_sizes
173+
# Convert to flat pixel indices for the left edge (first column)
174+
pixel_indices = rows * shape_native[1]
184175

176+
neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
177+
neighbors[pixel_indices, 1] = pixel_indices + 1
178+
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
179+
neighbors_sizes[pixel_indices] = 3
180+
181+
return neighbors, neighbors_sizes
185182

186-
@numba_util.jit()
187183
def rectangular_right_edge_neighbors(
188184
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
189185
) -> Tuple[np.ndarray, np.ndarray]:
@@ -206,21 +202,19 @@ def rectangular_right_edge_neighbors(
206202
-------
207203
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
208204
"""
209-
for pix in range(1, shape_native[0] - 1):
210-
pixel_index = pix * shape_native[1] + shape_native[1] - 1
211-
neighbors[pixel_index, 0:3] = np.array(
212-
[
213-
pixel_index - shape_native[1],
214-
pixel_index - 1,
215-
pixel_index + shape_native[1],
216-
]
217-
)
218-
neighbors_sizes[pixel_index] = 3
205+
# Rows excluding the top and bottom corners
206+
rows = np.arange(1, shape_native[0] - 1)
219207

220-
return neighbors, neighbors_sizes
208+
# Flat indices for the right edge pixels
209+
pixel_indices = rows * shape_native[1] + shape_native[1] - 1
221210

211+
neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
212+
neighbors[pixel_indices, 1] = pixel_indices - 1
213+
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
214+
neighbors_sizes[pixel_indices] = 3
215+
216+
return neighbors, neighbors_sizes
222217

223-
@numba_util.jit()
224218
def rectangular_bottom_edge_neighbors(
225219
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
226220
) -> Tuple[np.ndarray, np.ndarray]:
@@ -243,19 +237,21 @@ def rectangular_bottom_edge_neighbors(
243237
-------
244238
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
245239
"""
246-
pixels = int(shape_native[0] * shape_native[1])
240+
n_rows, n_cols = shape_native
241+
pixels = n_rows * n_cols
242+
243+
# Horizontal pixel positions along bottom row, excluding corners
244+
cols = np.arange(1, n_cols - 1)
245+
pixel_indices = pixels - cols - 1 # Reverse order from right to left
247246

248-
for pix in range(1, shape_native[1] - 1):
249-
pixel_index = pixels - pix - 1
250-
neighbors[pixel_index, 0:3] = np.array(
251-
[pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1]
252-
)
253-
neighbors_sizes[pixel_index] = 3
247+
neighbors[pixel_indices, 0] = pixel_indices - n_cols
248+
neighbors[pixel_indices, 1] = pixel_indices - 1
249+
neighbors[pixel_indices, 2] = pixel_indices + 1
250+
neighbors_sizes[pixel_indices] = 3
254251

255252
return neighbors, neighbors_sizes
256253

257254

258-
@numba_util.jit()
259255
def rectangular_central_neighbors(
260256
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
261257
) -> Tuple[np.ndarray, np.ndarray]:
@@ -279,46 +275,60 @@ def rectangular_central_neighbors(
279275
-------
280276
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
281277
"""
282-
for x in range(1, shape_native[0] - 1):
283-
for y in range(1, shape_native[1] - 1):
284-
pixel_index = x * shape_native[1] + y
285-
neighbors[pixel_index, 0:4] = np.array(
286-
[
287-
pixel_index - shape_native[1],
288-
pixel_index - 1,
289-
pixel_index + 1,
290-
pixel_index + shape_native[1],
291-
]
292-
)
293-
neighbors_sizes[pixel_index] = 4
278+
n_rows, n_cols = shape_native
294279

295-
return neighbors, neighbors_sizes
280+
# Grid coordinates excluding edges
281+
xs = np.arange(1, n_rows - 1)
282+
ys = np.arange(1, n_cols - 1)
296283

284+
# 2D grid of central pixel indices
285+
grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
286+
pixel_indices = grid_x * n_cols + grid_y
287+
pixel_indices = pixel_indices.ravel()
297288

298-
def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List:
299-
"""
300-
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization.
289+
# Compute neighbor indices
290+
neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up
291+
neighbors[pixel_indices, 1] = pixel_indices - 1 # Left
292+
neighbors[pixel_indices, 2] = pixel_indices + 1 # Right
293+
neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down
294+
295+
neighbors_sizes[pixel_indices] = 4
296+
297+
return neighbors, neighbors_sizes
301298

302-
This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there
303-
is at least one neighbor from the 4 expected missing.
299+
def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]:
300+
"""
301+
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization,
302+
based on its 2D shape.
304303
305304
Parameters
306305
----------
307-
neighbors
308-
An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the
309-
rectangular pixelization (entries of -1 correspond to no neighbor).
306+
shape_native
307+
The (rows, cols) shape of the rectangular 2D pixel grid.
310308
311309
Returns
312310
-------
313-
A list of the 1D indices of all pixels on the edge of a rectangular pixelization.
311+
A list of the 1D indices of all edge pixels.
314312
"""
315-
edge_pixel_list = []
313+
rows, cols = shape_native
314+
315+
# Top row
316+
top = np.arange(0, cols)
317+
318+
# Bottom row
319+
bottom = np.arange((rows - 1) * cols, rows * cols)
320+
321+
# Left column (excluding corners)
322+
left = np.arange(1, rows - 1) * cols
323+
324+
# Right column (excluding corners)
325+
right = (np.arange(1, rows - 1) + 1) * cols - 1
316326

317-
for i, neighbors in enumerate(neighbors):
318-
if -1 in neighbors:
319-
edge_pixel_list.append(i)
327+
# Concatenate all edge indices
328+
edge_pixel_indices = np.concatenate([top, left, right, bottom])
320329

321-
return edge_pixel_list
330+
# Sort and return
331+
return np.sort(edge_pixel_indices).tolist()
322332

323333

324334
@numba_util.jit()

autoarray/inversion/regularization/adaptive_brightness.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,4 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
115115
return regularization_util.weighted_regularization_matrix_from(
116116
regularization_weights=regularization_weights,
117117
neighbors=linear_obj.source_plane_mesh_grid.neighbors,
118-
neighbors_sizes=linear_obj.source_plane_mesh_grid.neighbors.sizes,
119118
)

0 commit comments

Comments
 (0)