Skip to content

Commit 44cd415

Browse files
committed
convolve_image_no_blurring
1 parent 27bf06a commit 44cd415

2 files changed

Lines changed: 49 additions & 8 deletions

File tree

autoarray/structures/arrays/kernel_2d.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,4 +527,41 @@ def convolve_image(self, image, blurring_image, jax_method="fft"):
527527

528528
convolved_array_1d = convolve_native[slim_to_native]
529529

530+
return Array2D(values=convolved_array_1d, mask=image.mask)
531+
532+
def convolve_image_no_blurring(self, image, jax_method="fft"):
533+
"""
534+
For a given 1D array and blurring array, convolve the two using this convolver.
535+
536+
Parameters
537+
----------
538+
image
539+
1D array of the values which are to be blurred with the convolver's PSF.
540+
blurring_image
541+
1D array of the blurring values which blur into the array after PSF convolution.
542+
jax_method
543+
If JAX is enabled this keyword will indicate what method is used for the PSF
544+
convolution. Can be either `direct` to calculate it in real space or `fft`
545+
to calculated it via a fast Fourier transform. `fft` is typically faster for
546+
kernels that are more than about 5x5. Default is `fft`.
547+
"""
548+
549+
slim_to_native = jnp.nonzero(
550+
jnp.logical_not(image.mask.array), size=image.shape[0]
551+
)
552+
553+
expanded_array_native = jnp.zeros(image.mask.shape)
554+
555+
expanded_array_native = expanded_array_native.at[slim_to_native].set(
556+
image.array
557+
)
558+
559+
kernel = np.array(self.native.array)
560+
561+
convolve_native = jax.scipy.signal.convolve(
562+
expanded_array_native, kernel, mode="same", method=jax_method
563+
)
564+
565+
convolved_array_1d = convolve_native[slim_to_native]
566+
530567
return Array2D(values=convolved_array_1d, mask=image.mask)

test_autoarray/structures/arrays/test_kernel_2d.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -398,22 +398,23 @@ def test__convolve_image():
398398
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4)
399399

400400

401-
def test__compare_to_full_2d_convolution__no_blurring_image():
401+
def test__convolve_image_no_blurring():
402402
# Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array.
403403

404-
import scipy.signal
405-
406404
mask = aa.Mask2D.circular(
407405
shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0
408406
)
409-
kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0)
410-
image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0)
407+
408+
import scipy.signal
409+
410+
kernel = np.arange(49).reshape(7, 7)
411+
image = np.arange(900).reshape(30, 30)
411412

412413
blurring_mask = mask.derive_mask.blurring_from(
413-
kernel_shape_native=kernel.shape_native
414+
kernel_shape_native=kernel.shape
414415
)
415416
blurred_image_via_scipy = scipy.signal.convolve2d(
416-
image.native * blurring_mask, kernel.native, mode="same"
417+
image * blurring_mask, kernel, mode="same"
417418
)
418419
blurred_image_via_scipy = aa.Array2D.no_mask(
419420
values=blurred_image_via_scipy, pixel_scales=1.0
@@ -424,11 +425,14 @@ def test__compare_to_full_2d_convolution__no_blurring_image():
424425

425426
# Now reproduce this data using the frame convolver_image
426427

428+
kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0)
429+
image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0)
430+
427431
masked_image = aa.Array2D(values=image.native, mask=mask)
428432

429433
blurred_masked_im_1 = kernel.convolve_image_no_blurring(image=masked_image)
430434

431-
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1, 1e-4)
435+
assert blurred_masked_image_via_scipy == pytest.approx(blurred_masked_im_1.array, 1e-4)
432436

433437

434438
def test__convolve_mapping_matrix():

0 commit comments

Comments
 (0)