Skip to content

Commit 39d2dd5

Browse files
Jammy2211Jammy2211
authored andcommitted
multiple mappers now supported
1 parent 446e5bc commit 39d2dd5

1 file changed

Lines changed: 66 additions & 23 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -382,52 +382,95 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
382382
@cached_property
383383
def zeroed_ids_to_keep(self):
384384
"""
385-
Return the **positive global indices** of linear objects to keep in the inversion,
386-
given **mesh-local positive pixel indices** to zero.
385+
Return the **positive global indices** of linear parameters that should be
386+
kept (solved for) in the inversion, accounting for **zeroed pixel indices**
387+
from one or more mappers.
387388
388-
Assumes the full linear system parameter vector is laid out as:
389+
---------------------------------------------------------------------------
390+
Parameter vector layout
391+
---------------------------------------------------------------------------
392+
This method assumes the full linear parameter vector is ordered as:
389393
390-
[ non-pixel linear objects ][ pixel block ]
394+
[ non-pixel linear objects ][ mapper_0 pixels ][ mapper_1 pixels ] ... [ mapper_M pixels ]
391395
392396
where:
393-
- the pixel block occupies the final `mesh.pixels` entries of the full vector
394-
- `mesh.zeroed_pixels` contains **positive** indices in the pixel block's own
395-
indexing (0 is top-left for rectangular meshes, increasing row-major)
396-
- all non-pixel linear objects are always kept
397397
398-
This implementation is backend-friendly (NumPy / JAX via `self._xp`) and returns
399-
indices suitable for advanced indexing of `data_vector` and curvature matrices.
398+
- *Non-pixel linear objects* include quantities such as analytic light
399+
profiles, regularization amplitudes, etc.
400+
- Each mapper contributes a contiguous block of pixel-based linear parameters.
401+
- The concatenated pixel blocks occupy the **final** entries of the parameter
402+
vector, with total length:
403+
404+
total_pixels = sum(mapper.mesh.pixels for mapper in mappers)
405+
406+
---------------------------------------------------------------------------
407+
Zeroed pixel convention
408+
---------------------------------------------------------------------------
409+
For each mapper:
410+
411+
- `mapper.mesh.zeroed_pixels` must be a 1D array of **positive, mesh-local**
412+
pixel indices in the range `[0, mapper.mesh.pixels - 1]`.
413+
- These indices identify pixels that should be **excluded** from the linear
414+
solve (e.g. edge pixels, masked regions, or padding pixels).
415+
- Indexing is defined purely within the mapper’s own pixelization (e.g.
416+
row-major flattening for rectangular meshes).
417+
418+
This method converts all mesh-local zeroed pixel indices into **global
419+
parameter indices**, correctly offsetting for:
420+
- the presence of non-pixel linear objects at the start of the vector
421+
- the cumulative pixel counts of preceding mappers
422+
423+
---------------------------------------------------------------------------
424+
Backend and implementation details
425+
---------------------------------------------------------------------------
426+
- The implementation is backend-agnostic and supports both NumPy and JAX via
427+
`self._xp`.
428+
- The returned indices are **positive global indices**, suitable for advanced
429+
indexing of:
430+
- `self.data_vector`
431+
- `self.curvature_reg_matrix`
432+
- When using JAX, this method avoids backend-incompatible operations and
433+
preserves JIT compatibility under the same constraints as the rest of the
434+
inversion pipeline.
400435
401436
Returns
402437
-------
403438
array-like
404-
1D array of **positive** indices to keep, sorted ascending.
439+
A 1D array of **positive global indices**, sorted in ascending order,
440+
corresponding to linear parameters that should be kept in the inversion.
405441
"""
406442
xp = self._xp
407443

408444
mapper_list = self.cls_list_from(cls=Mapper)
409-
mesh = mapper_list[0].mesh
410445

411446
n_total = int(self.total_params)
412-
n_pixels = int(mesh.pixels)
413-
414-
# Pixel block starts at this global index
415-
pixel_start = n_total - n_pixels
416-
417-
# Mesh-local positive pixel indices to zero (e.g. [0, 1, 2, ...] for edges)
418-
zeros_local = xp.asarray(mesh.zeroed_pixels)
419-
420-
# Convert to global positive indices
421-
zeros_global = pixel_start + zeros_local
447+
pixels_per_mapper = [int(m.mesh.pixels) for m in mapper_list]
448+
total_pixels = int(sum(pixels_per_mapper))
449+
450+
# Global start index of the concatenated pixel block
451+
pixel_start = n_total - total_pixels
452+
453+
# Build global indices to zero across all mappers
454+
zeros_global_list = []
455+
offset = 0
456+
for mapper, n_pix in zip(mapper_list, pixels_per_mapper):
457+
zeros_local = xp.asarray(mapper.mesh.zeroed_pixels, dtype=xp.int32)
458+
zeros_global_list.append(pixel_start + offset + zeros_local)
459+
offset += n_pix
460+
461+
zeros_global = (
462+
xp.concatenate(zeros_global_list) if len(zeros_global_list) > 0 else xp.asarray([], dtype=xp.int32)
463+
)
422464

423-
# Keep mask over full parameter vector
424465
keep = xp.ones((n_total,), dtype=bool)
425466

426467
if hasattr(keep, "at"):
468+
# JAX path
427469
keep = keep.at[zeros_global].set(False)
428470
keep_ids = xp.nonzero(keep, size=n_total)[0]
429471
keep_ids = keep_ids[: keep.sum()]
430472
else:
473+
# NumPy path
431474
keep[zeros_global] = False
432475
keep_ids = xp.nonzero(keep)[0]
433476

0 commit comments

Comments
 (0)