Skip to content

Commit 08165cd

Browse files
Jammy2211Jammy2211
authored andcommitted
over sampler xp simplified
1 parent e99bfaa commit 08165cd

3 files changed

Lines changed: 8 additions & 13 deletions

File tree

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def mapping_matrix_from(
516516
flat_pixidx = xp.where(flat_pixidx < 0, OUT, flat_pixidx)
517517

518518
# 5) Multiply by sub_fraction of the slim row
519-
flat_frac = sub_fraction[flat_parent] # (M_sub*B,)
519+
flat_frac = xp.take(sub_fraction, flat_parent, axis=0) # (M_sub*B,)
520520
flat_contrib = flat_w * flat_frac # (M_sub*B,)
521521

522522
# 6) Scatter into (M × (S+1)), summing duplicates

autoarray/operators/over_sampling/decorator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def over_sample(func):
3131
def wrapper(
3232
obj: object,
3333
grid: Union[np.ndarray, Grid2D, Grid2DIrregular, Grid1D],
34+
xp=np,
3435
*args,
3536
**kwargs,
3637
) -> Union[np.ndarray, Array1D, Array2D, ArrayIrregular, List]:
@@ -49,12 +50,12 @@ def wrapper(
4950
"""
5051

5152
if isinstance(grid, Grid2DIrregular) or isinstance(grid, Grid1D):
52-
return func(obj=obj, grid=grid, *args, **kwargs)
53+
return func(obj=obj, grid=grid, xp=xp, *args, **kwargs)
5354

5455
if obj is not None:
55-
values = func(obj, grid.over_sampled, *args, **kwargs)
56+
values = func(obj, grid.over_sampled, xp, *args, **kwargs)
5657
else:
57-
values = func(grid.over_sampled, *args, **kwargs)
58+
values = func(grid.over_sampled, xp, *args, **kwargs)
5859

5960
return grid.over_sampler.binned_array_2d_from(array=values)
6061

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
@register_pytree_node_class
1616
class OverSampler:
17-
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D], xp=np):
17+
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
1818
"""
1919
Over samples grid calculations using a uniform sub-grid.
2020
@@ -149,7 +149,7 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D], xp=np):
149149
self.sub_total = int(np.sum(self.sub_size**2))
150150
self.sub_length = self.sub_size**self.mask.dimensions
151151
self.sub_fraction = Array2D(
152-
values=xp.array(1.0 / self.sub_length.array), mask=self.mask
152+
values=1.0 / self.sub_length.array, mask=self.mask
153153
)
154154

155155
# Used for JAX based adaptive over sampling.
@@ -171,10 +171,6 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D], xp=np):
171171
):
172172
self.segment_ids[start:end] = seg_id
173173

174-
self.segment_ids = xp.array(self.segment_ids)
175-
176-
self.xp = xp
177-
178174
@property
179175
def sub_is_uniform(self) -> bool:
180176
"""
@@ -260,12 +256,10 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D":
260256
array, self.segment_ids, self.mask.pixels_in_mask
261257
)
262258
counts = jax.ops.segment_sum(
263-
self.xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
259+
np.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
264260
)
265261
binned_array_2d = sums / counts
266262

267-
binned_array_2d = self.xp.array(binned_array_2d)
268-
269263
return Array2D(
270264
values=binned_array_2d,
271265
mask=self.mask,

0 commit comments

Comments
 (0)