Skip to content

Commit 5677589

Browse files
Jammy2211Jammy2211
authored andcommitted
update data_weight_total_for_pix_from
1 parent a92d828 commit 5677589

1 file changed

Lines changed: 21 additions & 19 deletions

File tree

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -726,32 +726,34 @@ def mapped_to_source_via_mapping_matrix_from(
726726
return mapped_to_source
727727

728728

729-
@numba_util.jit()
730729
def data_weight_total_for_pix_from(
731-
pix_indexes_for_sub_slim_index: np.ndarray,
732-
pix_weights_for_sub_slim_index: np.ndarray,
730+
pix_indexes_for_sub_slim_index: np.ndarray, # shape (M, B)
731+
pix_weights_for_sub_slim_index: np.ndarray, # shape (M, B)
733732
pixels: int,
734733
) -> np.ndarray:
735734
"""
736-
Returns the total weight of every pixelization pixel, which is the sum of the weights of all data-points that
737-
map to that pixel.
735+
Returns the total weight of every pixelization pixel, which is the sum of
736+
the weights of all data‐points (sub‐pixels) that map to that pixel.
738737
739738
Parameters
740739
----------
741-
pix_indexes_for_sub_slim_index
742-
The mappings from a data sub-pixel index to a pixelization pixel index.
743-
pix_weights_for_sub_slim_index
744-
The weights of the mappings of every data sub-pixel and pixelization pixel.
745-
pixels
746-
The number of pixels in the pixelization.
747-
"""
740+
pix_indexes_for_sub_slim_index : np.ndarray, shape (M, B), int
741+
For each of M sub‐slim indexes, the B pixelization‐pixel indices it maps to.
742+
pix_weights_for_sub_slim_index : np.ndarray, shape (M, B), float
743+
For each of those mappings, the corresponding interpolation weight.
744+
pixels : int
745+
The total number of pixelization pixels N.
748746
749-
pix_weight_total = np.zeros(pixels)
747+
Returns
748+
-------
749+
np.ndarray, shape (N,)
750+
The per‐pixel total weight: for each j in [0..N-1], the sum of all
751+
pix_weights_for_sub_slim_index[i,k] such that pix_indexes_for_sub_slim_index[i,k] == j.
752+
"""
753+
# Flatten both arrays into 1D
754+
flat_idxs = pix_indexes_for_sub_slim_index.ravel()
755+
flat_weights = pix_weights_for_sub_slim_index.ravel()
750756

751-
for slim_index, pix_indexes in enumerate(pix_indexes_for_sub_slim_index):
752-
for pix_index, weight in zip(
753-
pix_indexes, pix_weights_for_sub_slim_index[slim_index]
754-
):
755-
pix_weight_total[int(pix_index)] += weight
757+
# Use bincount to sum weights at each index, ensuring length = pixels
758+
return np.bincount(flat_idxs, weights=flat_weights, minlength=pixels)
756759

757-
return pix_weight_total

0 commit comments

Comments
 (0)