|
8 | 8 |
|
9 | 9 | from autoconf.fitsable import header_obj_from |
10 | 10 |
|
11 | | -from autoarray.mask.mask_2d import Mask2D |
12 | 11 | from autoarray.structures.arrays.uniform_2d import AbstractArray2D |
13 | 12 | from autoarray.structures.arrays.uniform_2d import Array2D |
14 | 13 | from autoarray.structures.grids.uniform_2d import Grid2D |
|
17 | 16 | from autoarray import exc |
18 | 17 | from autoarray import type as ty |
19 | 18 | from autoarray.structures.arrays import array_2d_util |
| 19 | +from autoarray.mask.mask_2d import mask_2d_util |
20 | 20 |
|
21 | 21 |
|
22 | 22 | class Kernel2D(AbstractArray2D): |
@@ -483,28 +483,74 @@ def convolved_array_from(self, array: Array2D) -> Array2D: |
483 | 483 |
|
484 | 484 | return Array2D(values=convolved_array_1d, mask=array_2d.mask) |
485 | 485 |
|
486 | | - def jax_convolve(self, image, blurring_image, method="auto"): |
| 486 | + def convolve_image(self, image, blurring_image, jax_method="fft"): |
| 487 | + """ |
| 488 | + For a given 1D array and blurring array, convolve the two using this convolver. |
| 489 | +
|
| 490 | + Parameters |
| 491 | + ---------- |
| 492 | + image |
| 493 | + 1D array of the values which are to be blurred with the convolver's PSF. |
| 494 | + blurring_image |
| 495 | + 1D array of the blurring values which blur into the array after PSF convolution. |
| 496 | + jax_method |
| 497 | + If JAX is enabled this keyword will indicate what method is used for the PSF |
| 498 | + convolution. Can be either `direct` to calculate it in real space or `fft` |
| 499 | + to calculated it via a fast Fourier transform. `fft` is typically faster for |
| 500 | + kernels that are more than about 5x5. Default is `fft`. |
| 501 | + """ |
| 502 | + |
| 503 | + if self.mask.shape[0] % 2 == 0 or self.mask.shape[1] % 2 == 0: |
| 504 | + raise exc.KernelException("Kernel2D Kernel2D must be odd") |
| 505 | + |
| 506 | + print(type(image.native + blurring_image.native)) |
| 507 | + print(type(self.native)) |
| 508 | + |
| 509 | + convolved_array_2d = scipy.signal.convolve2d((image.native + blurring_image.native)._array, self.native._array, mode="same") |
487 | 510 |
|
| 511 | + convolved_array_1d = array_2d_util.array_2d_slim_from( |
| 512 | + mask_2d=np.array(image.mask), |
| 513 | + array_2d_native=convolved_array_2d, |
| 514 | + ) |
| 515 | + |
| 516 | + return Array2D(values=convolved_array_1d, mask=image.mask) |
| 517 | + |
| 518 | + def convolve_image_jax_from(self, array, blurring_array, method="auto"): |
| 519 | + """ |
| 520 | + For a given 1D array and blurring array, convolve the two using this convolver. |
| 521 | +
|
| 522 | + Parameters |
| 523 | + ---------- |
| 524 | + array |
| 525 | + 1D array of the values which are to be blurred with the convolver's PSF. |
| 526 | + blurring_array |
| 527 | + 1D array of the blurring values which blur into the array after PSF convolution. |
| 528 | + jax_method |
| 529 | + If JAX is enabled this keyword will indicate what method is used for the PSF |
| 530 | + convolution. Can be either `direct` to calculate it in real space or `fft` |
| 531 | + to calculated it via a fast Fourier transform. `fft` is typically faster for |
| 532 | + kernels that are more than about 5x5. Default is `fft`. |
| 533 | + """ |
488 | 534 | slim_to_native = jnp.nonzero( |
489 | | - jnp.logical_not(self.mask.array), size=image.shape[0] |
| 535 | + jnp.logical_not(self.mask.array), size=array.shape[0] |
490 | 536 | ) |
491 | 537 | slim_to_native_blurring = jnp.nonzero( |
492 | | - jnp.logical_not(self.blurring_mask), size=blurring_image.shape[0] |
| 538 | + jnp.logical_not(self.blurring_mask), size=blurring_array.shape[0] |
493 | 539 | ) |
494 | 540 |
|
495 | | - expanded_image_native = jnp.zeros(self.mask.shape) |
| 541 | + expanded_array_native = jnp.zeros(self.mask.shape) |
496 | 542 |
|
497 | | - expanded_image_native = expanded_image_native.at[slim_to_native].set( |
498 | | - image.array |
| 543 | + expanded_array_native = expanded_array_native.at[slim_to_native].set( |
| 544 | + array.array |
499 | 545 | ) |
500 | | - expanded_image_native = expanded_image_native.at[slim_to_native_blurring].set( |
501 | | - blurring_image.array |
| 546 | + expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set( |
| 547 | + blurring_array.array |
502 | 548 | ) |
503 | 549 |
|
504 | 550 | kernel = np.array(self.kernel.native.array) |
505 | 551 |
|
506 | 552 | convolve_native = jax.scipy.signal.convolve( |
507 | | - expanded_image_native, kernel, mode="same", method=method |
| 553 | + expanded_array_native, kernel, mode="same", method=method |
508 | 554 | ) |
509 | 555 |
|
510 | 556 | convolve_slim = convolve_native[slim_to_native] |
|
0 commit comments