@@ -86,6 +86,7 @@ def constant_zeroth_regularization_matrix_from(
8686 coefficient : float ,
8787 coefficient_zeroth : float ,
8888 neighbors : np .ndarray ,
89+ neighbors_sizes : np .ndarray [[int ], np .int64 ],
8990) -> np .ndarray :
9091 """
9192 From the pixel-neighbors array, setup the regularization matrix using the instance regularization scheme.
@@ -111,33 +112,28 @@ class in the module ``autoarray.inversion.regularization``.
111112 coefficient of every source pixel is the same.
112113 """
113114 S , P = neighbors .shape
114- reg1 = coefficient ** 2
115- reg0 = coefficient_zeroth ** 2
116-
117- # 1) Flatten (i,j) neighbor‐pairs
118- I = jnp .repeat (jnp .arange (S ), P ) # (S*P,)
119- J = neighbors .reshape (- 1 ) # (S*P,)
120-
121- # 2) Remap “no neighbor” = -1 → OUT = S
122- OUT = S
123- J = jnp .where (J < 0 , OUT , J )
124-
125- # 3) Start on an (S+1)x(S+1) zero canvas
126- M = jnp .zeros ((S + 1 , S + 1 ), dtype = jnp .float32 )
127-
128- # 4) Diagonal baseline: 1e-8 + reg0 for i in [0..S-1]
129- diag_base = jnp .concatenate ([jnp .full ((S ,), 1e-8 + reg0 ), jnp .zeros ((1 ,))])
130- M = M .at [jnp .diag_indices (S + 1 )].add (diag_base )
115+ # as the regularization matrix is S by S, S would be out of bound (any out of bound index would do)
116+ OUT_OF_BOUND_IDX = S
117+ regularization_coefficient = coefficient * coefficient
131118
132- # 5) Scatter the first-order reg1 into diag[i] for each neighbor (i→j):
133- # M[i,i] += reg1
134- M = M .at [I , I ].add (reg1 )
119+ # flatten it for feeding into the matrix as j indices
120+ neighbors = neighbors .flatten ()
121+ # now create the corresponding i indices
122+ I_IDX = jnp .repeat (jnp .arange (S ), P )
123+ # Entries of `-1` in `neighbors` (indicating no neighbor) are replaced with an out-of-bounds index.
124+ # This ensures that JAX can efficiently drop these entries during matrix updates.
125+ neighbors = jnp .where (neighbors == - 1 , OUT_OF_BOUND_IDX , neighbors )
126+ const = (
127+ jnp .diag (1e-8 + regularization_coefficient * neighbors_sizes ).at [I_IDX , neighbors ]
128+ # unique indices should be guranteed by neighbors-spec
129+ .add (- regularization_coefficient , mode = "drop" , unique_indices = True )
130+ )
135131
136- # 6) Scatter the off-diagonals: M[i,j] -= reg1
137- M = M .at [I , J ].add (- reg1 )
132+ reg_coeff = coefficient_zeroth ** 2.0
133+ # Identity matrix scaled by reg_coeff does exactly ∑_i reg_coeff * e_i e_i^T
134+ zeroth = jnp .eye (P ) * reg_coeff
138135
139- # 7) Return only the top‐left S×S block
140- return M [:S , :S ]
136+ return const + zeroth
141137
142138def adaptive_regularization_weights_from (
143139 inner_coefficient : float , outer_coefficient : float , pixel_signals : np .ndarray
0 commit comments