@@ -439,40 +439,44 @@ def zeroed_ids_to_keep(self):
439439 A 1D array of **positive global indices**, sorted in ascending order,
440440 corresponding to linear parameters that should be kept in the inversion.
441441 """
442- xp = self ._xp
443442
444443 mapper_list = self .cls_list_from (cls = Mapper )
445444
446445 n_total = int (self .total_params )
446+
447447 pixels_per_mapper = [int (m .mesh .pixels ) for m in mapper_list ]
448448 total_pixels = int (sum (pixels_per_mapper ))
449449
450- # Global start index of the concatenated pixel block
450+ # Global start index of concatenated pixel block
451451 pixel_start = n_total - total_pixels
452452
453- # Build global indices to zero across all mappers
453+ # Total number of zeroed pixels across all mappers (Python int => static)
454+ total_zeroed = int (sum (len (m .mesh .zeroed_pixels ) for m in mapper_list ))
455+ n_keep = int (n_total - total_zeroed )
456+
457+ # Build global indices-to-zero across all mappers
454458 zeros_global_list = []
455459 offset = 0
456- for mapper , n_pix in zip (mapper_list , pixels_per_mapper ):
457- zeros_local = xp . asarray (mapper .mesh .zeroed_pixels , dtype = xp .int32 )
460+ for m , n_pix in zip (mapper_list , pixels_per_mapper ):
461+ zeros_local = self . _xp . asarray (m .mesh .zeroed_pixels , dtype = self . _xp .int32 )
458462 zeros_global_list .append (pixel_start + offset + zeros_local )
459463 offset += n_pix
460464
461465 zeros_global = (
462- xp .concatenate (zeros_global_list ) if len (zeros_global_list ) > 0 else xp .asarray ([], dtype = xp .int32 )
466+ self ._xp .concatenate (zeros_global_list )
467+ if len (zeros_global_list ) > 0
468+ else self ._xp .asarray ([], dtype = self ._xp .int32 )
463469 )
464470
465- keep = xp .ones ((n_total ,), dtype = bool )
471+ keep = self . _xp .ones ((n_total ,), dtype = bool )
466472
467- if hasattr (keep , "at" ):
468- # JAX path
469- keep = keep .at [zeros_global ].set (False )
470- keep_ids = xp .nonzero (keep , size = n_total )[0 ]
471- keep_ids = keep_ids [: keep .sum ()]
472- else :
473- # NumPy path
473+ if self ._xp is np :
474474 keep [zeros_global ] = False
475- keep_ids = xp .nonzero (keep )[0 ]
475+ keep_ids = self ._xp .nonzero (keep )[0 ]
476+
477+ else :
478+ keep = keep .at [zeros_global ].set (False )
479+ keep_ids = self ._xp .nonzero (keep , size = n_keep )[0 ]
476480
477481 return keep_ids
478482
0 commit comments