Skip to content

Commit 6fa35a6

Browse files
committed
simplify jax_convolve
1 parent 5762fcf commit 6fa35a6

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

autoarray/structures/arrays/kernel_2d.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -484,22 +484,29 @@ def convolved_array_from(self, array: Array2D) -> Array2D:
484484
return Array2D(values=convolved_array_1d, mask=array_2d.mask)
485485

486486
def jax_convolve(self, image, blurring_image, method="auto"):
487-
slim_to_2D_index_image = jnp.nonzero(
487+
488+
slim_to_native = jnp.nonzero(
488489
jnp.logical_not(self.mask.array), size=image.shape[0]
489490
)
490-
slim_to_2D_index_blurring = jnp.nonzero(
491+
slim_to_native_blurring = jnp.nonzero(
491492
jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0]
492493
)
494+
493495
expanded_image_native = jnp.zeros(self.mask.shape)
494-
expanded_image_native = expanded_image_native.at[slim_to_2D_index_image].set(
496+
497+
expanded_image_native = expanded_image_native.at[slim_to_native].set(
495498
image.array
496499
)
497-
expanded_image_native = expanded_image_native.at[slim_to_2D_index_blurring].set(
500+
expanded_image_native = expanded_image_native.at[slim_to_native_blurring].set(
498501
blurring_image.array
499502
)
503+
500504
kernel = np.array(self.kernel.native.array)
505+
501506
convolve_native = jax.scipy.signal.convolve(
502507
expanded_image_native, kernel, mode="same", method=method
503508
)
504-
convolve_slim = convolve_native[slim_to_2D_index_image]
509+
510+
convolve_slim = convolve_native[slim_to_native]
511+
505512
return convolve_slim

0 commit comments

Comments
 (0)