@@ -382,52 +382,95 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
382382 @cached_property
383383 def zeroed_ids_to_keep (self ):
384384 """
385- Return the **positive global indices** of linear objects to keep in the inversion,
386- given **mesh-local positive pixel indices** to zero.
385+ Return the **positive global indices** of linear parameters that should be
386+ kept (solved for) in the inversion, accounting for **zeroed pixel indices**
387+ from one or more mappers.
387388
388- Assumes the full linear system parameter vector is laid out as:
389+ ---------------------------------------------------------------------------
390+ Parameter vector layout
391+ ---------------------------------------------------------------------------
392+ This method assumes the full linear parameter vector is ordered as:
389393
390- [ non-pixel linear objects ][ pixel block ]
394+ [ non-pixel linear objects ][ mapper_0 pixels ][ mapper_1 pixels ] ... [ mapper_M pixels ]
391395
392396 where:
393- - the pixel block occupies the final `mesh.pixels` entries of the full vector
394- - `mesh.zeroed_pixels` contains **positive** indices in the pixel block's own
395- indexing (0 is top-left for rectangular meshes, increasing row-major)
396- - all non-pixel linear objects are always kept
397397
398- This implementation is backend-friendly (NumPy / JAX via `self._xp`) and returns
399- indices suitable for advanced indexing of `data_vector` and curvature matrices.
398+ - *Non-pixel linear objects* include quantities such as analytic light
399+ profiles, regularization amplitudes, etc.
400+ - Each mapper contributes a contiguous block of pixel-based linear parameters.
401+ - The concatenated pixel blocks occupy the **final** entries of the parameter
402+ vector, with total length:
403+
404+ total_pixels = sum(mapper.mesh.pixels for mapper in mappers)
405+
406+ ---------------------------------------------------------------------------
407+ Zeroed pixel convention
408+ ---------------------------------------------------------------------------
409+ For each mapper:
410+
411+ - `mapper.mesh.zeroed_pixels` must be a 1D array of **positive, mesh-local**
412+ pixel indices in the range `[0, mapper.mesh.pixels - 1]`.
413+ - These indices identify pixels that should be **excluded** from the linear
414+ solve (e.g. edge pixels, masked regions, or padding pixels).
415+ - Indexing is defined purely within the mapper’s own pixelization (e.g.
416+ row-major flattening for rectangular meshes).
417+
418+ This method converts all mesh-local zeroed pixel indices into **global
419+ parameter indices**, correctly offsetting for:
420+ - the presence of non-pixel linear objects at the start of the vector
421+ - the cumulative pixel counts of preceding mappers
422+
423+ ---------------------------------------------------------------------------
424+ Backend and implementation details
425+ ---------------------------------------------------------------------------
426+ - The implementation is backend-agnostic and supports both NumPy and JAX via
427+ `self._xp`.
428+ - The returned indices are **positive global indices**, suitable for advanced
429+ indexing of:
430+ - `self.data_vector`
431+ - `self.curvature_reg_matrix`
432+ - When using JAX, this method avoids backend-incompatible operations and
433+ preserves JIT compatibility under the same constraints as the rest of the
434+ inversion pipeline.
400435
401436 Returns
402437 -------
403438 array-like
404- 1D array of **positive** indices to keep, sorted ascending.
439+ A 1D array of **positive global indices**, sorted in ascending order,
440+ corresponding to linear parameters that should be kept in the inversion.
405441 """
406442 xp = self ._xp
407443
408444 mapper_list = self .cls_list_from (cls = Mapper )
409- mesh = mapper_list [0 ].mesh
410445
411446 n_total = int (self .total_params )
412- n_pixels = int (mesh .pixels )
413-
414- # Pixel block starts at this global index
415- pixel_start = n_total - n_pixels
416-
417- # Mesh-local positive pixel indices to zero (e.g. [0, 1, 2, ...] for edges)
418- zeros_local = xp .asarray (mesh .zeroed_pixels )
419-
420- # Convert to global positive indices
421- zeros_global = pixel_start + zeros_local
447+ pixels_per_mapper = [int (m .mesh .pixels ) for m in mapper_list ]
448+ total_pixels = int (sum (pixels_per_mapper ))
449+
450+ # Global start index of the concatenated pixel block
451+ pixel_start = n_total - total_pixels
452+
453+ # Build global indices to zero across all mappers
454+ zeros_global_list = []
455+ offset = 0
456+ for mapper , n_pix in zip (mapper_list , pixels_per_mapper ):
457+ zeros_local = xp .asarray (mapper .mesh .zeroed_pixels , dtype = xp .int32 )
458+ zeros_global_list .append (pixel_start + offset + zeros_local )
459+ offset += n_pix
460+
461+ zeros_global = (
462+ xp .concatenate (zeros_global_list ) if len (zeros_global_list ) > 0 else xp .asarray ([], dtype = xp .int32 )
463+ )
422464
423- # Keep mask over full parameter vector
424465 keep = xp .ones ((n_total ,), dtype = bool )
425466
426467 if hasattr (keep , "at" ):
468+ # JAX path
427469 keep = keep .at [zeros_global ].set (False )
428470 keep_ids = xp .nonzero (keep , size = n_total )[0 ]
429471 keep_ids = keep_ids [: keep .sum ()]
430472 else :
473+ # NumPy path
431474 keep [zeros_global ] = False
432475 keep_ids = xp .nonzero (keep )[0 ]
433476
0 commit comments