Skip to content

Commit 150a360

Browse files
Jammy2211Jammy2211
authored andcommitted
move utils to their specific modules
1 parent a509014 commit 150a360

7 files changed

Lines changed: 325 additions & 321 deletions

File tree

autoarray/inversion/regularization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .gaussian_kernel import GaussianKernel
1111
from .exponential_kernel import ExponentialKernel
1212
from .matern_kernel import MaternKernel
13+

autoarray/inversion/regularization/adaptive_brightness.py

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,121 @@
11
from __future__ import annotations
2-
import numpy as np
2+
import jax.numpy as jnp
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
66
from autoarray.inversion.linear_obj.linear_obj import LinearObj
77

88
from autoarray.inversion.regularization.abstract import AbstractRegularization
99

10-
from autoarray.inversion.regularization import regularization_util
1110

11+
def adaptive_regularization_weights_from(
12+
inner_coefficient: float, outer_coefficient: float, pixel_signals: jnp.ndarray
13+
) -> jnp.ndarray:
14+
"""
15+
Returns the regularization weights for the adaptive regularization scheme (e.g. ``AdaptiveBrightness``).
16+
17+
The weights define the effective regularization coefficient of every mesh parameter (typically pixels
18+
of a ``Mapper``).
19+
20+
They are computed using an estimate of the expected signal in each pixel.
21+
22+
Two regularization coefficients are used, corresponding to the:
23+
24+
1) pixel_signals: pixels with a high pixel-signal (i.e. where the signal is located in the pixelization).
25+
2) 1.0 - pixel_signals: pixels with a low pixel-signal (i.e. where the signal is not located in the pixelization).
26+
27+
Parameters
28+
----------
29+
inner_coefficient
30+
The inner regularization coefficients which controls the degree of smoothing of the inversion reconstruction
31+
in the inner regions of a mesh's reconstruction.
32+
outer_coefficient
33+
The outer regularization coefficients which controls the degree of smoothing of the inversion reconstruction
34+
in the outer regions of a mesh's reconstruction.
35+
pixel_signals
36+
The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal
37+
and low signal pixelizations.
38+
39+
Returns
40+
-------
41+
jnp.ndarray
42+
The adaptive regularization weights which act as the effective regularization coefficients of
43+
every source pixel.
44+
"""
45+
return (
46+
inner_coefficient * pixel_signals + outer_coefficient * (1.0 - pixel_signals)
47+
) ** 2.0
48+
49+
50+
def weighted_regularization_matrix_from(
51+
regularization_weights: jnp.ndarray,
52+
neighbors: jnp.ndarray,
53+
) -> jnp.ndarray:
54+
"""
55+
Returns the regularization matrix of the adaptive regularization scheme (e.g. ``AdaptiveBrightness``).
56+
57+
This matrix is computed using the regularization weights of every mesh pixel, which are computed using the
58+
function ``adaptive_regularization_weights_from``. These act as the effective regularization coefficients of
59+
every mesh pixel.
60+
61+
The regularization matrix is computed using the pixel-neighbors array, which is setup using the appropriate
62+
neighbor calculation of the corresponding ``Mapper`` class.
63+
64+
Parameters
65+
----------
66+
regularization_weights
67+
The regularization weight of each pixel, adaptively governing the degree of gradient regularization
68+
applied to each inversion parameter (e.g. mesh pixels of a ``Mapper``).
69+
neighbors
70+
An array of length (total_pixels) which provides the index of all neighbors of every pixel in
71+
the mesh grid (entries of -1 correspond to no neighbor).
72+
neighbors_sizes
73+
An array of length (total_pixels) which gives the number of neighbors of every pixel in the
74+
Voronoi grid.
75+
76+
Returns
77+
-------
78+
jnp.ndarray
79+
The regularization matrix computed using an adaptive regularization scheme where the effective regularization
80+
coefficient of every source pixel is different.
81+
"""
82+
S, P = neighbors.shape
83+
reg_w = regularization_weights ** 2
84+
85+
# 1) Flatten the (i→j) neighbor pairs
86+
I = jnp.repeat(jnp.arange(S), P) # (S*P,)
87+
J = neighbors.reshape(-1) # (S*P,)
88+
89+
# 2) Remap “no neighbor” entries to an extra slot S, whose weight=0
90+
OUT = S
91+
J = jnp.where(J < 0, OUT, J)
92+
93+
# 3) Build an extended weight vector with a zero at index S
94+
reg_w_ext = jnp.concatenate([reg_w, jnp.zeros((1,))], axis=0)
95+
w_ij = reg_w_ext[J] # (S*P,)
96+
97+
# 4) Start with zeros on an (S+1)x(S+1) canvas so we can scatter into row S safely
98+
mat = jnp.zeros((S + 1, S + 1), dtype=regularization_weights.dtype)
99+
100+
# 5) Scatter into the diagonal:
101+
# - the tiny 1e-8 floor on each i < S
102+
# - sum_j reg_w[j] into diag[i]
103+
# - sum contributions reg_w[j] into diag[j]
104+
# (diagonal at OUT=S picks up zeros only)
105+
diag_updates_i = jnp.concatenate([
106+
jnp.full((S,), 1e-8),
107+
jnp.zeros((1,)) # out‐of‐bounds slot stays zero
108+
], axis=0)
109+
mat = mat.at[jnp.diag_indices(S + 1)].add(diag_updates_i)
110+
mat = mat.at[I, I].add(w_ij)
111+
mat = mat.at[J, J].add(w_ij)
112+
113+
# 6) Scatter the off‐diagonal subtractions:
114+
mat = mat.at[I, J].add(-w_ij)
115+
mat = mat.at[J, I].add(-w_ij)
116+
117+
# 7) Drop the extra row/column S and return the S×S result
118+
return mat[:S, :S]
12119

13120
class AdaptiveBrightness(AbstractRegularization):
14121
def __init__(
@@ -70,7 +177,7 @@ def __init__(
70177
self.outer_coefficient = outer_coefficient
71178
self.signal_scale = signal_scale
72179

73-
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
180+
def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray:
74181
"""
75182
Returns the regularization weights of this regularization scheme.
76183
@@ -91,13 +198,13 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
91198
"""
92199
pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale)
93200

94-
return regularization_util.adaptive_regularization_weights_from(
201+
return adaptive_regularization_weights_from(
95202
inner_coefficient=self.inner_coefficient,
96203
outer_coefficient=self.outer_coefficient,
97204
pixel_signals=pixel_signals,
98205
)
99206

100-
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
207+
def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray:
101208
"""
102209
Returns the regularization matrix with shape [pixels, pixels].
103210
@@ -112,7 +219,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
112219
"""
113220
regularization_weights = self.regularization_weights_from(linear_obj=linear_obj)
114221

115-
return regularization_util.weighted_regularization_matrix_from(
222+
return weighted_regularization_matrix_from(
116223
regularization_weights=regularization_weights,
117224
neighbors=linear_obj.source_plane_mesh_grid.neighbors,
118225
)

autoarray/inversion/regularization/brightness_zeroth.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
import numpy as np
2+
import jax.numpy as jnp
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
@@ -10,6 +10,60 @@
1010
from autoarray.inversion.regularization import regularization_util
1111

1212

13+
def brightness_zeroth_regularization_weights_from(
14+
coefficient: float, pixel_signals: jnp.ndarray
15+
) -> jnp.ndarray:
16+
"""
17+
Returns the regularization weights for the brightness zeroth regularization scheme (e.g. ``BrightnessZeroth``).
18+
19+
The weights define the level of zeroth order regularization applied to every mesh parameter (typically pixels
20+
of a ``Mapper``).
21+
22+
They are computed using an estimate of the expected signal in each pixel.
23+
24+
The zeroth order regularization coefficients is applied in combination with 1.0 - pixel_signals, which are
25+
the pixels with a low pixel-signal (i.e. where the signal is not located near the source being reconstructed in
26+
the pixelization).
27+
28+
Parameters
29+
----------
30+
coefficient
31+
The level of zeroth order regularization applied to every mesh parameter (typically pixels of a ``Mapper``),
32+
with the degree applied varying based on the ``pixel_signals``.
33+
pixel_signals
34+
The estimated signal in every pixelization pixel, used to change the regularization weighting of high signal
35+
and low signal pixelizations.
36+
37+
Returns
38+
-------
39+
jnp.ndarray
40+
The zeroth order regularization weights which act as the effective level of zeroth order regularization
41+
applied to every mesh parameter.
42+
"""
43+
return coefficient * (1.0 - pixel_signals)
44+
45+
def brightness_zeroth_regularization_matrix_from(
46+
regularization_weights: jnp.ndarray,
47+
) -> jnp.ndarray:
48+
"""
49+
Returns the regularization matrix for the zeroth-order brightness regularization scheme.
50+
51+
Parameters
52+
----------
53+
regularization_weights
54+
The regularization weights for each pixel, governing the strength of zeroth-order
55+
regularization applied per inversion parameter.
56+
57+
Returns
58+
-------
59+
A diagonal regularization matrix where each diagonal element is the squared regularization weight
60+
for that pixel.
61+
"""
62+
regularization_weight_squared = regularization_weights**2.0
63+
return jnp.diag(regularization_weight_squared)
64+
65+
66+
1367
class BrightnessZeroth(AbstractRegularization):
1468
def __init__(
1569
self,
@@ -45,7 +99,7 @@ def __init__(
4599
self.coefficient = coefficient
46100
self.signal_scale = signal_scale
47101

48-
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
102+
def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray:
49103
"""
50104
Returns the regularization weights of the ``BrightnessZeroth`` regularization scheme.
51105
@@ -65,11 +119,11 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
65119
"""
66120
pixel_signals = linear_obj.pixel_signals_from(signal_scale=self.signal_scale)
67121

68-
return regularization_util.brightness_zeroth_regularization_weights_from(
122+
return brightness_zeroth_regularization_weights_from(
69123
coefficient=self.coefficient, pixel_signals=pixel_signals
70124
)
71125

72-
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
126+
def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray:
73127
"""
74128
Returns the regularization matrix with shape [pixels, pixels].
75129
@@ -84,6 +138,6 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
84138
"""
85139
regularization_weights = self.regularization_weights_from(linear_obj=linear_obj)
86140

87-
return regularization_util.brightness_zeroth_regularization_matrix_from(
141+
return brightness_zeroth_regularization_matrix_from(
88142
regularization_weights=regularization_weights
89143
)

autoarray/inversion/regularization/constant.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,61 @@
11
from __future__ import annotations
2-
import numpy as np
2+
import jax.numpy as jnp
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
66
from autoarray.inversion.linear_obj.linear_obj import LinearObj
77

88
from autoarray.inversion.regularization.abstract import AbstractRegularization
99

10-
from autoarray.inversion.regularization import regularization_util
10+
11+
def constant_regularization_matrix_from(
12+
coefficient: float,
13+
neighbors: jnp.ndarray[[int, int], jnp.int64],
14+
neighbors_sizes: jnp.ndarray[[int], jnp.int64],
15+
) -> jnp.ndarray[[int, int], jnp.float64]:
16+
"""
17+
From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme.
18+
19+
A complete description of regularizatin and the `regularization_matrix` can be found in the `Regularization`
20+
class in the module `autoarray.inversion.regularization`.
21+
22+
Memory requirement: 2SP + S^2
23+
FLOPS: 1 + 2S + 2SP
24+
25+
Parameters
26+
----------
27+
coefficient
28+
The regularization coefficients which controls the degree of smoothing of the inversion reconstruction.
29+
neighbors : ndarray, shape (S, P), dtype=int64
30+
An array of length (total_pixels) which provides the index of all neighbors of every pixel in
31+
the Voronoi grid (entries of -1 correspond to no neighbor).
32+
neighbors_sizes : ndarray, shape (S,), dtype=int64
33+
An array of length (total_pixels) which gives the number of neighbors of every pixel in the
34+
Voronoi grid.
35+
36+
Returns
37+
-------
38+
regularization_matrix : ndarray, shape (S, S), dtype=float64
39+
The regularization matrix computed using Regularization where the effective regularization
40+
coefficient of every source pixel is the same.
41+
"""
42+
S, P = neighbors.shape
43+
# as the regularization matrix is S by S, S would be out of bound (any out of bound index would do)
44+
OUT_OF_BOUND_IDX = S
45+
regularization_coefficient = coefficient * coefficient
46+
47+
# flatten it for feeding into the matrix as j indices
48+
neighbors = neighbors.flatten()
49+
# now create the corresponding i indices
50+
I_IDX = jnp.repeat(jnp.arange(S), P)
51+
# Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index.
52+
# This ensures that JAX can efficiently drop these entries during matrix updates.
53+
neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors)
54+
return (
55+
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors]
56+
# unique indices should be guranteed by neighbors-spec
57+
.add(-regularization_coefficient, mode="drop", unique_indices=True)
58+
)
1159

1260

1361
class Constant(AbstractRegularization):
@@ -38,7 +86,7 @@ def __init__(self, coefficient: float = 1.0):
3886

3987
super().__init__()
4088

41-
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
89+
def regularization_weights_from(self, linear_obj: LinearObj) -> jnp.ndarray:
4290
"""
4391
Returns the regularization weights of this regularization scheme.
4492
@@ -57,9 +105,9 @@ def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
57105
-------
58106
The regularization weights.
59107
"""
60-
return self.coefficient * np.ones(linear_obj.params)
108+
return self.coefficient * jnp.ones(linear_obj.params)
61109

62-
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
110+
def regularization_matrix_from(self, linear_obj: LinearObj) -> jnp.ndarray:
63111
"""
64112
Returns the regularization matrix with shape [pixels, pixels].
65113
@@ -73,7 +121,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
73121
The regularization matrix.
74122
"""
75123

76-
return regularization_util.constant_regularization_matrix_from(
124+
return constant_regularization_matrix_from(
77125
coefficient=self.coefficient,
78126
neighbors=linear_obj.neighbors,
79127
neighbors_sizes=linear_obj.neighbors.sizes,

0 commit comments

Comments
 (0)