@@ -379,6 +379,60 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
379379 # Zero rows and columns in the matrix we want to ignore
380380 return self .curvature_reg_matrix [ids_to_keep ][:, ids_to_keep ]
381381
382+ @cached_property
383+ def zeroed_ids_to_keep (self ):
384+ """
385+ Return the **positive global indices** of linear objects to keep in the inversion,
386+ given **mesh-local positive pixel indices** to zero.
387+
388+ Assumes the full linear system parameter vector is laid out as:
389+
390+ [ non-pixel linear objects ][ pixel block ]
391+
392+ where:
393+ - the pixel block occupies the final `mesh.pixels` entries of the full vector
394+ - `mesh.zeroed_pixels` contains **positive** indices in the pixel block's own
395+ indexing (0 is top-left for rectangular meshes, increasing row-major)
396+ - all non-pixel linear objects are always kept
397+
398+ This implementation is backend-friendly (NumPy / JAX via `self._xp`) and returns
399+ indices suitable for advanced indexing of `data_vector` and curvature matrices.
400+
401+ Returns
402+ -------
403+ array-like
404+ 1D array of **positive** indices to keep, sorted ascending.
405+ """
406+ xp = self ._xp
407+
408+ mapper_list = self .cls_list_from (cls = Mapper )
409+ mesh = mapper_list [0 ].mesh
410+
411+ n_total = int (self .total_params )
412+ n_pixels = int (mesh .pixels )
413+
414+ # Pixel block starts at this global index
415+ pixel_start = n_total - n_pixels
416+
417+ # Mesh-local positive pixel indices to zero (e.g. [0, 1, 2, ...] for edges)
418+ zeros_local = xp .asarray (mesh .zeroed_pixels )
419+
420+ # Convert to global positive indices
421+ zeros_global = pixel_start + zeros_local
422+
423+ # Keep mask over full parameter vector
424+ keep = xp .ones ((n_total ,), dtype = bool )
425+
426+ if hasattr (keep , "at" ):
427+ keep = keep .at [zeros_global ].set (False )
428+ keep_ids = xp .nonzero (keep , size = n_total )[0 ]
429+ keep_ids = keep_ids [: keep .sum ()]
430+ else :
431+ keep [zeros_global ] = False
432+ keep_ids = xp .nonzero (keep )[0 ]
433+
434+ return keep_ids
435+
382436 @cached_property
383437 def reconstruction (self ) -> np .ndarray :
384438 """
@@ -400,16 +454,11 @@ def reconstruction(self) -> np.ndarray:
400454
401455 if self .settings .use_edge_zeroed_pixels and self .has (cls = Mapper ):
402456
403- mapper_list = self .cls_list_from (cls = Mapper )
404-
405- # ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
406- ids_to_keep = mapper_list [0 ].mesh .zeroed_pixels_to_keep
407-
408457 # Use advanced indexing to select rows/columns
409- data_vector = self .data_vector [ids_to_keep ]
410- curvature_reg_matrix = self .curvature_reg_matrix [ids_to_keep ][
411- :, ids_to_keep
412- ]
458+ data_vector = self .data_vector [self . zeroed_ids_to_keep ]
459+ curvature_reg_matrix = self .curvature_reg_matrix [
460+ self . zeroed_ids_to_keep
461+ ][:, self . zeroed_ids_to_keep ]
413462
414463 # Perform reconstruction via fnnls
415464 reconstruction_partial = (
@@ -426,11 +475,11 @@ def reconstruction(self) -> np.ndarray:
426475
427476 # Scatter the partial solution back to the full shape
428477 if self ._xp .__name__ .startswith ("jax" ):
429- reconstruction = reconstruction .at [ids_to_keep ].set (
478+ reconstruction = reconstruction .at [self . zeroed_ids_to_keep ].set (
430479 reconstruction_partial
431480 )
432481 else :
433- reconstruction [ids_to_keep ] = reconstruction_partial
482+ reconstruction [self . zeroed_ids_to_keep ] = reconstruction_partial
434483
435484 return reconstruction
436485
0 commit comments