Skip to content

Commit 485e076

Browse files
Jammy2211Jammy2211
authored andcommitted
inned supeer sampling array uses jax.ops.segment_sum for fast compile
1 parent 42aea76 commit 485e076

4 files changed

Lines changed: 16 additions & 17 deletions

File tree

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ def adaptive_rectangular_mappings_weights_via_interpolation_from(
236236
# --- Step 3. Transform oversampled grid into index space ---
237237
grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale
238238
grid_over_sampled_transformed = transform(grid_over_sampled_scaled)
239-
grid_over_index = source_grid_size * grid_over_sampled_transformed
240239
grid_over_index = (source_grid_size - 3) * grid_over_sampled_transformed + 1
241240

242241
# --- Step 4. Floor/ceil indices ---

autoarray/inversion/pixelization/mappers/rectangular.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,6 @@ def pix_sub_weights(self) -> PixSubWeights:
9797
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
9898
are equal to 1.0.
9999
"""
100-
101-
# mappings, weights = (
102-
# mapper_util.rectangular_mappings_weights_via_interpolation_from(
103-
# shape_native=self.shape_native,
104-
# source_plane_mesh_grid=self.source_plane_mesh_grid.array,
105-
# source_plane_data_grid=self.source_plane_data_grid.over_sampled,
106-
# )
107-
# )
108-
109100
mappings, weights = (
110101
mapper_util.adaptive_rectangular_mappings_weights_via_interpolation_from(
111102
source_grid_size=self.shape_native[0],

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import jax.numpy as jnp
3+
import jax
34
from typing import Union
45

56
from autoconf import conf
@@ -164,6 +165,14 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
164165
# Ensure correct concatenation by making 0 a JAX array
165166
self.start_indices = np.concatenate((np.array([0]), self.split_indices[:-1]))
166167

168+
# Compute segment ids for each element in the flattened array
169+
self.segment_ids = np.empty(np.sum(sub_size**2), dtype=np.int32)
170+
171+
for seg_id, (start, end) in enumerate(zip(self.start_indices, self.split_indices)):
172+
self.segment_ids[start:end] = seg_id
173+
174+
self.segment_ids = jnp.array(self.segment_ids)
175+
167176
@property
168177
def sub_is_uniform(self) -> bool:
169178
"""
@@ -234,18 +243,18 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D":
234243
pass
235244

236245
if self.sub_is_uniform:
246+
237247
binned_array_2d = array.reshape(
238248
self.mask.shape_slim, self.sub_size[0] ** 2
239249
).mean(axis=1)
250+
240251
else:
241252

242253
# Compute the group means
243-
binned_array_2d = jnp.array(
244-
[
245-
array[start:end].mean()
246-
for start, end in zip(self.start_indices, self.split_indices)
247-
]
248-
)
254+
255+
sums = jax.ops.segment_sum(array, self.segment_ids, self.mask.pixels_in_mask)
256+
counts = jax.ops.segment_sum(jnp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask)
257+
binned_array_2d = sums / counts
249258

250259
return Array2D(
251260
values=binned_array_2d,

autoarray/plot/mat_plot/two_d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def _plot_rectangular_mapper(
564564
else:
565565
ax = self.setup_subplot(aspect=aspect_inv)
566566

567-
shape_native = mapper.source_plane_mesh_grid.shape_native
567+
shape_native = mapper.source_plane_mesh_grid.shape_native
568568

569569
if pixel_values is not None:
570570

0 commit comments

Comments
 (0)