|
1 | 1 | import numpy as np |
2 | 2 | import jax.numpy as jnp |
| 3 | +import jax |
3 | 4 | from typing import Union |
4 | 5 |
|
5 | 6 | from autoconf import conf |
@@ -164,6 +165,14 @@ def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]): |
164 | 165 | # Ensure correct concatenation by making 0 a JAX array |
165 | 166 | self.start_indices = np.concatenate((np.array([0]), self.split_indices[:-1])) |
166 | 167 |
|
| 168 | + # Compute segment ids for each element in the flattened array |
| 169 | + self.segment_ids = np.empty(np.sum(sub_size**2), dtype=np.int32) |
| 170 | + |
| 171 | + for seg_id, (start, end) in enumerate(zip(self.start_indices, self.split_indices)): |
| 172 | + self.segment_ids[start:end] = seg_id |
| 173 | + |
| 174 | + self.segment_ids = jnp.array(self.segment_ids) |
| 175 | + |
167 | 176 | @property |
168 | 177 | def sub_is_uniform(self) -> bool: |
169 | 178 | """ |
@@ -234,18 +243,18 @@ def binned_array_2d_from(self, array: Array2D) -> "Array2D": |
234 | 243 | pass |
235 | 244 |
|
236 | 245 | if self.sub_is_uniform: |
| 246 | + |
237 | 247 | binned_array_2d = array.reshape( |
238 | 248 | self.mask.shape_slim, self.sub_size[0] ** 2 |
239 | 249 | ).mean(axis=1) |
| 250 | + |
240 | 251 | else: |
241 | 252 |
|
242 | 253 | # Compute the group means |
243 | | - binned_array_2d = jnp.array( |
244 | | - [ |
245 | | - array[start:end].mean() |
246 | | - for start, end in zip(self.start_indices, self.split_indices) |
247 | | - ] |
248 | | - ) |
| 254 | + |
| 255 | + sums = jax.ops.segment_sum(array, self.segment_ids, self.mask.pixels_in_mask) |
| 256 | + counts = jax.ops.segment_sum(jnp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask) |
| 257 | + binned_array_2d = sums / counts |
249 | 258 |
|
250 | 259 | return Array2D( |
251 | 260 | values=binned_array_2d, |
|
0 commit comments