Skip to content

Commit 0386bdd

Browse files
committed
test__convolve_image
1 parent 6fa35a6 commit 0386bdd

2 files changed

Lines changed: 98 additions & 20 deletions

File tree

autoarray/structures/arrays/kernel_2d.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from autoconf.fitsable import header_obj_from
1010

11-
from autoarray.mask.mask_2d import Mask2D
1211
from autoarray.structures.arrays.uniform_2d import AbstractArray2D
1312
from autoarray.structures.arrays.uniform_2d import Array2D
1413
from autoarray.structures.grids.uniform_2d import Grid2D
@@ -17,6 +16,7 @@
1716
from autoarray import exc
1817
from autoarray import type as ty
1918
from autoarray.structures.arrays import array_2d_util
19+
from autoarray.mask.mask_2d import mask_2d_util
2020

2121

2222
class Kernel2D(AbstractArray2D):
@@ -483,28 +483,74 @@ def convolved_array_from(self, array: Array2D) -> Array2D:
483483

484484
return Array2D(values=convolved_array_1d, mask=array_2d.mask)
485485

486-
def jax_convolve(self, image, blurring_image, method="auto"):
486+
def convolve_image(self, image, blurring_image, jax_method="fft"):
487+
"""
488+
For a given 1D array and blurring array, convolve the two using this convolver.
489+
490+
Parameters
491+
----------
492+
image
493+
1D array of the values which are to be blurred with the convolver's PSF.
494+
blurring_image
495+
1D array of the blurring values which blur into the array after PSF convolution.
496+
jax_method
497+
If JAX is enabled this keyword will indicate what method is used for the PSF
498+
convolution. Can be either `direct` to calculate it in real space or `fft`
499+
to calculated it via a fast Fourier transform. `fft` is typically faster for
500+
kernels that are more than about 5x5. Default is `fft`.
501+
"""
502+
503+
if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0:
504+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
505+
506+
print(type(image.native + blurring_image.native))
507+
print(type(self.native))
508+
509+
convolved_array_2d = scipy.signal.convolve2d((image.native + blurring_image.native)._array, self.native._array, mode="same")
487510

511+
convolved_array_1d = array_2d_util.array_2d_slim_from(
512+
mask_2d=np.array(image.mask),
513+
array_2d_native=convolved_array_2d,
514+
)
515+
516+
return Array2D(values=convolved_array_1d, mask=image.mask)
517+
518+
def convolve_image_jax_from(self, array, blurring_array, method="auto"):
519+
"""
520+
For a given 1D array and blurring array, convolve the two using this convolver.
521+
522+
Parameters
523+
----------
524+
array
525+
1D array of the values which are to be blurred with the convolver's PSF.
526+
blurring_array
527+
1D array of the blurring values which blur into the array after PSF convolution.
528+
jax_method
529+
If JAX is enabled this keyword will indicate what method is used for the PSF
530+
convolution. Can be either `direct` to calculate it in real space or `fft`
531+
to calculated it via a fast Fourier transform. `fft` is typically faster for
532+
kernels that are more than about 5x5. Default is `fft`.
533+
"""
488534
slim_to_native = jnp.nonzero(
489-
jnp.logical_not(self.mask.array), size=image.shape[0]
535+
jnp.logical_not(self.mask.array), size=array.shape[0]
490536
)
491537
slim_to_native_blurring = jnp.nonzero(
492-
jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0]
538+
jnp.logical_not(self.blurring_mask), size=blurring_array.shape[0]
493539
)
494540

495-
expanded_image_native = jnp.zeros(self.mask.shape)
541+
expanded_array_native = jnp.zeros(self.mask.shape)
496542

497-
expanded_image_native = expanded_image_native.at[slim_to_native].set(
498-
image.array
543+
expanded_array_native = expanded_array_native.at[slim_to_native].set(
544+
array.array
499545
)
500-
expanded_image_native = expanded_image_native.at[slim_to_native_blurring].set(
501-
blurring_image.array
546+
expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set(
547+
blurring_array.array
502548
)
503549

504550
kernel = np.array(self.kernel.native.array)
505551

506552
convolve_native = jax.scipy.signal.convolve(
507-
expanded_image_native, kernel, mode="same", method=method
553+
expanded_array_native, kernel, mode="same", method=method
508554
)
509555

510556
convolve_slim = convolve_native[slim_to_native]

test_autoarray/structures/arrays/test_kernel_2d.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from astropy.io import fits
21
from astropy import units
32
from astropy.modeling import functional_models
43
from astropy.coordinates import Angle
4+
import jax.numpy as jnp
55
import numpy as np
66
import pytest
77
from os import path
8-
import os
98

109
import autoarray as aa
1110
from autoarray import exc
@@ -359,6 +358,36 @@ def test__convolved_array_from():
359358
).all()
360359

361360

361+
def test__convolved_array_from__input_jax_array():
362+
363+
array_2d = jnp.array(
364+
[
365+
[0.0, 0.0, 0.0, 0.0],
366+
[1.0, 0.0, 0.0, 0.0],
367+
[0.0, 0.0, 0.0, 1.0],
368+
[0.0, 0.0, 0.0, 0.0],
369+
])
370+
371+
kernel_2d = aa.Kernel2D.no_mask(
372+
values=[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 3.0, 3.0]], pixel_scales=1.0
373+
)
374+
375+
blurred_array_2d = kernel_2d.convolved_array_from(array_2d)
376+
377+
assert (
378+
blurred_array_2d.native
379+
== np.array(
380+
[
381+
[1.0, 1.0, 0.0, 0.0],
382+
[2.0, 1.0, 1.0, 1.0],
383+
[3.0, 3.0, 2.0, 2.0],
384+
[0.0, 0.0, 1.0, 3.0],
385+
]
386+
)
387+
).all()
388+
389+
390+
362391
def test__convolve_mapping_matrix():
363392
mask = np.array(
364393
[
@@ -482,19 +511,19 @@ def test__convolve_mapping_matrix():
482511
)
483512

484513

485-
def test__compare_to_full_2d_convolution():
486-
# Setup a blurred data, using the PSF to perform the convolution in 2D, then masks it to make a 1d array.
487-
488-
import scipy.signal
514+
def test__convolve_image():
489515

490516
mask = aa.Mask2D.circular(
491517
shape_native=(30, 30), pixel_scales=(1.0, 1.0), radius=4.0
492518
)
493-
kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0)
494-
image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0)
519+
520+
import scipy.signal
521+
522+
kernel = np.arange(49).reshape(7, 7)
523+
image = np.arange(900).reshape(30, 30)
495524

496525
blurred_image_via_scipy = scipy.signal.convolve2d(
497-
image.native, kernel.native, mode="same"
526+
image, kernel, mode="same"
498527
)
499528
blurred_image_via_scipy = aa.Array2D.no_mask(
500529
values=blurred_image_via_scipy, pixel_scales=1.0
@@ -503,7 +532,10 @@ def test__compare_to_full_2d_convolution():
503532
values=blurred_image_via_scipy.native, mask=mask
504533
)
505534

506-
# Now reproduce this data using the frame convolver_image
535+
# Now reproduce this data using the convolve_image function
536+
537+
image = aa.Array2D.no_mask(values=np.arange(900).reshape(30, 30), pixel_scales=1.0)
538+
kernel = aa.Kernel2D.no_mask(values=np.arange(49).reshape(7, 7), pixel_scales=1.0)
507539

508540
masked_image = aa.Array2D(values=image.native, mask=mask)
509541

0 commit comments

Comments
 (0)