11from __future__ import annotations
2- import numpy as np
2+ import jax . numpy as jnp
33from typing import TYPE_CHECKING
44
55if TYPE_CHECKING :
66 from autoarray .inversion .linear_obj .linear_obj import LinearObj
77
88from autoarray .inversion .regularization .abstract import AbstractRegularization
99
10- from autoarray .inversion .regularization import regularization_util
1110
11+ def adaptive_regularization_weights_from (
12+ inner_coefficient : float , outer_coefficient : float , pixel_signals : jnp .ndarray
13+ ) -> jnp .ndarray :
14+ """
15+ Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``).
16+
17+ The weights define the effective regularization coefficient of every mesh parameter (typically pixels
18+ of a ``Mapper``).
19+
20+ They are computed using an estimate of the expected signal in each pixel.
21+
22+ Two regularization coefficients are used, corresponding to the:
23+
24+ 1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization).
25+ 2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization).
26+
27+ Parameters
28+ ----------
29+ inner_coefficient
30+ The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction
31+ in the inner regions of a mesh's reconstruction.
32+ outer_coefficient
33+ The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction
34+ in the outer regions of a mesh's reconstruction.
35+ pixel_signals
36+ The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal
37+ and low signal pixelizations.
38+
39+ Returns
40+ -------
41+ jnp.ndarray
42+ The adaptive regularization weights which act as the effective regularization coefficients of
43+ every source pixel.
44+ """
45+ return (
46+ inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals )
47+ ) ** 2.0
48+
49+
50+ def weighted_regularization_matrix_from (
51+ regularization_weights : jnp .ndarray ,
52+ neighbors : jnp .ndarray ,
53+ ) -> jnp .ndarray :
54+ """
55+ Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``).
56+
57+ This matrix is computed using the regularization weights of every mesh pixel, which are computed using the
58+ function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of
59+ every mesh pixel.
60+
61+ The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate
62+ neighbor calculation of the corresponding ``Mapper`` class.
63+
64+ Parameters
65+ ----------
66+ regularization_weights
67+ The regularization weight of each pixel, adaptively governing the degree of gradient regularization
68+ applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``).
69+ neighbors
70+ An array of length (total_pixels) which provides the index of all neighbors of every pixel in
71+ the mesh grid (entries of -1 correspond to no neighbor).
72+ neighbors_sizes
73+ An array of length (total_pixels) which gives the number of neighbors of every pixel in the
74+ Voronoi grid.
75+
76+ Returns
77+ -------
78+ jnp.ndarray
79+ The regularization matrix computed using an adaptive regularization scheme where the effective regularization
80+ coefficient of every source pixel is different.
81+ """
82+ S , P = neighbors .shape
83+ reg_w = regularization_weights ** 2
84+
85+ # 1) Flatten the (i→j) neighbor pairs
86+ I = jnp .repeat (jnp .arange (S ), P ) # (S*P,)
87+ J = neighbors .reshape (- 1 ) # (S*P,)
88+
89+ # 2) Remap “no neighbor” entries to an extra slot S, whose weight=0
90+ OUT = S
91+ J = jnp .where (J < 0 , OUT , J )
92+
93+ # 3) Build an extended weight vector with a zero at index S
94+ reg_w_ext = jnp .concatenate ([reg_w , jnp .zeros ((1 ,))], axis = 0 )
95+ w_ij = reg_w_ext [J ] # (S*P,)
96+
97+ # 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely
98+ mat = jnp .zeros ((S + 1 , S + 1 ), dtype = regularization_weights .dtype )
99+
100+ # 5) Scatter into the diagonal:
101+ # - the tiny 1e-8 floor on each i < S
102+ # - sum_j reg_w[j] into diag[i]
103+ # - sum contributions reg_w[j] into diag[j]
104+ # (diagonal at OUT=S picks up zeros only)
105+ diag_updates_i = jnp .concatenate ([
106+ jnp .full ((S ,), 1e-8 ),
107+ jnp .zeros ((1 ,)) # out‐of‐bounds slot stays zero
108+ ], axis = 0 )
109+ mat = mat .at [jnp .diag_indices (S + 1 )].add (diag_updates_i )
110+ mat = mat .at [I , I ].add (w_ij )
111+ mat = mat .at [J , J ].add (w_ij )
112+
113+ # 6) Scatter the off‐diagonal subtractions:
114+ mat = mat .at [I , J ].add (- w_ij )
115+ mat = mat .at [J , I ].add (- w_ij )
116+
117+ # 7) Drop the extra row/column S and return the S×S result
118+ return mat [:S , :S ]
12119
13120class AdaptiveBrightness (AbstractRegularization ):
14121 def __init__ (
@@ -70,7 +177,7 @@ def __init__(
70177 self .outer_coefficient = outer_coefficient
71178 self .signal_scale = signal_scale
72179
73- def regularization_weights_from (self , linear_obj : LinearObj ) -> np .ndarray :
180+ def regularization_weights_from (self , linear_obj : LinearObj ) -> jnp .ndarray :
74181 """
75182 Returns the regularization weights of this regularization scheme.
76183
@@ -91,13 +198,13 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
91198 """
92199 pixel_signals = linear_obj .pixel_signals_from (signal_scale = self .signal_scale )
93200
94- return regularization_util . adaptive_regularization_weights_from (
201+ return adaptive_regularization_weights_from (
95202 inner_coefficient = self .inner_coefficient ,
96203 outer_coefficient = self .outer_coefficient ,
97204 pixel_signals = pixel_signals ,
98205 )
99206
100- def regularization_matrix_from (self , linear_obj : LinearObj ) -> np .ndarray :
207+ def regularization_matrix_from (self , linear_obj : LinearObj ) -> jnp .ndarray :
101208 """
102209 Returns the regularization matrix with shape [pixels, pixels].
103210
@@ -112,7 +219,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
112219 """
113220 regularization_weights = self .regularization_weights_from (linear_obj = linear_obj )
114221
115- return regularization_util . weighted_regularization_matrix_from (
222+ return weighted_regularization_matrix_from (
116223 regularization_weights = regularization_weights ,
117224 neighbors = linear_obj .source_plane_mesh_grid .neighbors ,
118225 )
0 commit comments