Skip to content

Commit cff3930

Browse files
Jammy2211Jammy2211
authored andcommitted
multi mapper zeroing support
1 parent 39d2dd5 commit cff3930

1 file changed

Lines changed: 19 additions & 15 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -439,40 +439,44 @@ def zeroed_ids_to_keep(self):
439439
A 1D array of **positive global indices**, sorted in ascending order,
440440
corresponding to linear parameters that should be kept in the inversion.
441441
"""
442-
xp = self._xp
443442

444443
mapper_list = self.cls_list_from(cls=Mapper)
445444

446445
n_total = int(self.total_params)
446+
447447
pixels_per_mapper = [int(m.mesh.pixels) for m in mapper_list]
448448
total_pixels = int(sum(pixels_per_mapper))
449449

450-
# Global start index of the concatenated pixel block
450+
# Global start index of concatenated pixel block
451451
pixel_start = n_total - total_pixels
452452

453-
# Build global indices to zero across all mappers
453+
# Total number of zeroed pixels across all mappers (Python int => static)
454+
total_zeroed = int(sum(len(m.mesh.zeroed_pixels) for m in mapper_list))
455+
n_keep = int(n_total - total_zeroed)
456+
457+
# Build global indices-to-zero across all mappers
454458
zeros_global_list = []
455459
offset = 0
456-
for mapper, n_pix in zip(mapper_list, pixels_per_mapper):
457-
zeros_local = xp.asarray(mapper.mesh.zeroed_pixels, dtype=xp.int32)
460+
for m, n_pix in zip(mapper_list, pixels_per_mapper):
461+
zeros_local = self._xp.asarray(m.mesh.zeroed_pixels, dtype=self._xp.int32)
458462
zeros_global_list.append(pixel_start + offset + zeros_local)
459463
offset += n_pix
460464

461465
zeros_global = (
462-
xp.concatenate(zeros_global_list) if len(zeros_global_list) > 0 else xp.asarray([], dtype=xp.int32)
466+
self._xp.concatenate(zeros_global_list)
467+
if len(zeros_global_list) > 0
468+
else self._xp.asarray([], dtype=self._xp.int32)
463469
)
464470

465-
keep = xp.ones((n_total,), dtype=bool)
471+
keep = self._xp.ones((n_total,), dtype=bool)
466472

467-
if hasattr(keep, "at"):
468-
# JAX path
469-
keep = keep.at[zeros_global].set(False)
470-
keep_ids = xp.nonzero(keep, size=n_total)[0]
471-
keep_ids = keep_ids[: keep.sum()]
472-
else:
473-
# NumPy path
473+
if self._xp is np:
474474
keep[zeros_global] = False
475-
keep_ids = xp.nonzero(keep)[0]
475+
keep_ids = self._xp.nonzero(keep)[0]
476+
477+
else:
478+
keep = keep.at[zeros_global].set(False)
479+
keep_ids = self._xp.nonzero(keep, size=n_keep)[0]
476480

477481
return keep_ids
478482

0 commit comments

Comments
 (0)