Skip to content

Commit a509014

Browse files
Jammy2211Jammy2211
authored andcommitted
fix constant zeroth
1 parent 8a6951a commit a509014

2 files changed

Lines changed: 23 additions & 24 deletions

File tree

autoarray/inversion/regularization/regularization_util.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def constant_zeroth_regularization_matrix_from(
8686
coefficient: float,
8787
coefficient_zeroth: float,
8888
neighbors: np.ndarray,
89+
neighbors_sizes: np.ndarray[[int], np.int64],
8990
) -> np.ndarray:
9091
"""
9192
From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme.
@@ -111,33 +112,28 @@ class in the module ``autoarray.inversion.regularization``.
111112
coefficient of every source pixel is the same.
112113
"""
113114
S, P = neighbors.shape
114-
reg1 = coefficient**2
115-
reg0 = coefficient_zeroth**2
116-
117-
# 1) Flatten (i,j) neighbor‐pairs
118-
I = jnp.repeat(jnp.arange(S), P) # (S*P,)
119-
J = neighbors.reshape(-1) # (S*P,)
120-
121-
# 2) Remap “no neighbor” = -1 → OUT = S
122-
OUT = S
123-
J = jnp.where(J < 0, OUT, J)
124-
125-
# 3) Start on an (S+1)x(S+1) zero canvas
126-
M = jnp.zeros((S+1, S+1), dtype=jnp.float32)
127-
128-
# 4) Diagonal baseline: 1e-8 + reg0 for i in [0..S-1]
129-
diag_base = jnp.concatenate([jnp.full((S,), 1e-8 + reg0), jnp.zeros((1,))])
130-
M = M.at[jnp.diag_indices(S+1)].add(diag_base)
115+
# as the regularization matrix is S by S, S would be out of bound (any out of bound index would do)
116+
OUT_OF_BOUND_IDX = S
117+
regularization_coefficient = coefficient * coefficient
131118

132-
# 5) Scatter the first-order reg1 into diag[i] for each neighbor (i→j):
133-
# M[i,i] += reg1
134-
M = M.at[I, I].add(reg1)
119+
# flatten it for feeding into the matrix as j indices
120+
neighbors = neighbors.flatten()
121+
# now create the corresponding i indices
122+
I_IDX = jnp.repeat(jnp.arange(S), P)
123+
# Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index.
124+
# This ensures that JAX can efficiently drop these entries during matrix updates.
125+
neighbors = jnp.where(neighbors == -1, OUT_OF_BOUND_IDX, neighbors)
126+
const = (
127+
jnp.diag(1e-8 + regularization_coefficient * neighbors_sizes).at[I_IDX, neighbors]
128+
# unique indices should be guranteed by neighbors-spec
129+
.add(-regularization_coefficient, mode="drop", unique_indices=True)
130+
)
135131

136-
# 6) Scatter the off-diagonals: M[i,j] -= reg1
137-
M = M.at[I, J].add(-reg1)
132+
reg_coeff = coefficient_zeroth ** 2.0
133+
# Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T
134+
zeroth = jnp.eye(P) * reg_coeff
138135

139-
# 7) Return only the top‐left S×S block
140-
return M[:S, :S]
136+
return const + zeroth
141137

142138
def adaptive_regularization_weights_from(
143139
inner_coefficient: float, outer_coefficient: float, pixel_signals: np.ndarray

test_autoarray/inversion/regularizations/test_regularization_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,14 @@ def test__constant_regularization_matrix_from():
147147
def test__constant_zeroth_regularization_matrix_from():
148148
neighbors = np.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]])
149149

150+
neighbors_sizes = np.array([2, 1, 1])
151+
150152
regularization_matrix = (
151153
aa.util.regularization.constant_zeroth_regularization_matrix_from(
152154
coefficient=2.0,
153155
coefficient_zeroth=0.5,
154156
neighbors=neighbors,
157+
neighbors_sizes=neighbors_sizes,
155158
)
156159
)
157160

0 commit comments

Comments
 (0)