@@ -498,7 +498,6 @@ def remove_bad_entries_voronoi_nn(
498498 return pix_weights_for_sub_slim_index , pix_indexes_for_sub_slim_index
499499
500500
501- @numba_util .jit ()
502501def adaptive_pixel_signals_from (
503502 pixels : int ,
504503 pixel_weights : np .ndarray ,
@@ -536,30 +535,43 @@ def adaptive_pixel_signals_from(
536535 The image of the galaxy which is used to compute the weigghted pixel signals.
537536 """
538537
539- pixel_signals = np .zeros ((pixels ,))
540- pixel_sizes = np .zeros ((pixels ,))
538+ M_sub , B = pix_indexes_for_sub_slim_index .shape
541539
542- for sub_slim_index in range (len (pix_indexes_for_sub_slim_index )):
543- vertices_indexes = pix_indexes_for_sub_slim_index [sub_slim_index ]
540+ # 1) Flatten the per‐mapping tables:
541+ flat_pixidx = pix_indexes_for_sub_slim_index .reshape (- 1 ) # (M_sub*B,)
542+ flat_weights = pixel_weights .reshape (- 1 ) # (M_sub*B,)
544543
545- mask_1d_index = slim_index_for_sub_slim_index [sub_slim_index ]
544+ # 2) Build a matching “parent‐slim” index for each flattened entry:
545+ I_sub = jnp .repeat (jnp .arange (M_sub ), B ) # (M_sub*B,)
546546
547- pix_size_tem = pix_size_for_sub_slim_index [sub_slim_index ]
547+ # 3) Mask out any k >= pix_size_for_sub_slim_index[i]
548+ valid = (I_sub < 0 ) # dummy to get shape
549+ # better:
550+ valid = (jnp .arange (B )[None , :] < pix_size_for_sub_slim_index [:, None ]).reshape (- 1 )
548551
549- if pix_size_tem > 1 :
550- pixel_signals [vertices_indexes [:pix_size_tem ]] += (
551- adapt_data [mask_1d_index ] * pixel_weights [sub_slim_index ]
552- )
553- pixel_sizes [vertices_indexes ] += 1
554- else :
555- pixel_signals [vertices_indexes [0 ]] += adapt_data [mask_1d_index ]
556- pixel_sizes [vertices_indexes [0 ]] += 1
552+ flat_weights = jnp .where (valid , flat_weights , 0.0 )
553+ flat_pixidx = jnp .where (valid , flat_pixidx , pixels ) # send invalid indices to an out-of-bounds slot
554+
555+ # 4) Look up data & multiply by mapping weights:
556+ flat_data_vals = adapt_data [slim_index_for_sub_slim_index ][I_sub ] # (M_sub*B,)
557+ flat_contrib = flat_data_vals * flat_weights # (M_sub*B,)
558+
559+ # 5) Scatter‐add into signal sums and counts:
560+ pixel_signals = jnp .zeros ((pixels + 1 ,)).at [flat_pixidx ].add (flat_contrib )
561+ pixel_counts = jnp .zeros ((pixels + 1 ,)).at [flat_pixidx ].add (valid .astype (float ))
562+
563+ # 6) Drop the extra “out-of-bounds” slot:
564+ pixel_signals = pixel_signals [:pixels ]
565+ pixel_counts = pixel_counts [:pixels ]
557566
558- pixel_sizes [pixel_sizes == 0 ] = 1
559- pixel_signals /= pixel_sizes
560- pixel_signals /= np .max (pixel_signals )
567+ # 7) Normalize
568+ pixel_counts = jnp .where (pixel_counts > 0 , pixel_counts , 1.0 )
569+ pixel_signals = pixel_signals / pixel_counts
570+ max_sig = jnp .max (pixel_signals )
571+ pixel_signals = jnp .where (max_sig > 0 , pixel_signals / max_sig , pixel_signals )
561572
562- return pixel_signals ** signal_scale
573+ # 8) Exponentiate
574+ return pixel_signals ** signal_scale
563575
564576
565577def mapping_matrix_from (
0 commit comments