Skip to content

Commit 537b5ef

Browse files
committed
convolve_image now only uses JAX
1 parent 0386bdd commit 537b5ef

2 files changed

Lines changed: 16 additions & 41 deletions

File tree

autoarray/structures/arrays/kernel_2d.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def __init__(
5454
store_native=store_native,
5555
)
5656

57+
if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0:
58+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
59+
5760
if normalize:
5861
self._array = np.divide(self._array, np.sum(self._array))
5962

@@ -500,59 +503,28 @@ def convolve_image(self, image, blurring_image, jax_method="fft"):
500503
kernels that are more than about 5x5. Default is `fft`.
501504
"""
502505

503-
if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0:
504-
raise exc.KernelException("Kernel2D Kernel2D must be odd")
505-
506-
print(type(image.native + blurring_image.native))
507-
print(type(self.native))
508-
509-
convolved_array_2d = scipy.signal.convolve2d((image.native + blurring_image.native)._array, self.native._array, mode="same")
510-
511-
convolved_array_1d = array_2d_util.array_2d_slim_from(
512-
mask_2d=np.array(image.mask),
513-
array_2d_native=convolved_array_2d,
514-
)
515-
516-
return Array2D(values=convolved_array_1d, mask=image.mask)
517-
518-
def convolve_image_jax_from(self, array, blurring_array, method="auto"):
519-
"""
520-
For a given 1D array and blurring array, convolve the two using this convolver.
521-
522-
Parameters
523-
----------
524-
array
525-
1D array of the values which are to be blurred with the convolver's PSF.
526-
blurring_array
527-
1D array of the blurring values which blur into the array after PSF convolution.
528-
jax_method
529-
If JAX is enabled this keyword will indicate what method is used for the PSF
530-
convolution. Can be either `direct` to calculate it in real space or `fft`
531-
to calculated it via a fast Fourier transform. `fft` is typically faster for
532-
kernels that are more than about 5x5. Default is `fft`.
533-
"""
534506
slim_to_native = jnp.nonzero(
535-
jnp.logical_not(self.mask.array), size=array.shape[0]
507+
jnp.logical_not(image.mask.array), size=image.shape[0]
536508
)
537509
slim_to_native_blurring = jnp.nonzero(
538-
jnp.logical_not(self.blurring_mask), size=blurring_array.shape[0]
510+
jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
539511
)
540512

541-
expanded_array_native = jnp.zeros(self.mask.shape)
513+
expanded_array_native = jnp.zeros(image.mask.shape)
542514

543515
expanded_array_native = expanded_array_native.at[slim_to_native].set(
544-
array.array
516+
image.array
545517
)
546518
expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set(
547-
blurring_array.array
519+
blurring_image.array
548520
)
549521

550-
kernel = np.array(self.kernel.native.array)
522+
kernel = np.array(self.native.array)
551523

552524
convolve_native = jax.scipy.signal.convolve(
553-
expanded_array_native, kernel, mode="same", method=method
525+
expanded_array_native, kernel, mode="same", method=jax_method
554526
)
555527

556-
convolve_slim = convolve_native[slim_to_native]
528+
convolved_array_1d = convolve_native[slim_to_native]
557529

558-
return convolve_slim
530+
return Array2D(values=convolved_array_1d, mask=image.mask)

test_autoarray/structures/arrays/test_kernel_2d.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,10 @@ def test__convolve_image():
549549
image=masked_image, blurring_image=blurring_image
550550
)
551551

552-
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4)
552+
print(blurred_masked_image_via_scipy)
553+
print(blurred_masked_im_1)
554+
555+
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4)
553556

554557

555558
def test__compare_to_full_2d_convolution__no_blurring_image():

0 commit comments

Comments
 (0)