@@ -54,6 +54,9 @@ def __init__(
5454 store_native = store_native ,
5555 )
5656
57+ if self .mask .shape [0 ] % 2 == 0 or self .mask .shape [1 ] % 2 == 0 :
58+ raise exc .KernelException ("Kernel2D Kernel2D must be odd" )
59+
5760 if normalize :
5861 self ._array = np .divide (self ._array , np .sum (self ._array ))
5962
@@ -500,59 +503,28 @@ def convolve_image(self, image, blurring_image, jax_method="fft"):
500503 kernels that are more than about 5x5. Default is `fft`.
501504 """
502505
503- if self .mask .shape [0 ] % 2 == 0 or self .mask .shape [1 ] % 2 == 0 :
504- raise exc .KernelException ("Kernel2D Kernel2D must be odd" )
505-
506- print (type (image .native + blurring_image .native ))
507- print (type (self .native ))
508-
509- convolved_array_2d = scipy .signal .convolve2d ((image .native + blurring_image .native )._array , self .native ._array , mode = "same" )
510-
511- convolved_array_1d = array_2d_util .array_2d_slim_from (
512- mask_2d = np .array (image .mask ),
513- array_2d_native = convolved_array_2d ,
514- )
515-
516- return Array2D (values = convolved_array_1d , mask = image .mask )
517-
518- def convolve_image_jax_from (self , array , blurring_array , method = "auto" ):
519- """
520- For a given 1D array and blurring array, convolve the two using this convolver.
521-
522- Parameters
523- ----------
524- array
525- 1D array of the values which are to be blurred with the convolver's PSF.
526- blurring_array
527- 1D array of the blurring values which blur into the array after PSF convolution.
528- jax_method
529- If JAX is enabled this keyword will indicate what method is used for the PSF
530- convolution. Can be either `direct` to calculate it in real space or `fft`
531- to calculated it via a fast Fourier transform. `fft` is typically faster for
532- kernels that are more than about 5x5. Default is `fft`.
533- """
534506 slim_to_native = jnp .nonzero (
535- jnp .logical_not (self .mask .array ), size = array .shape [0 ]
507+ jnp .logical_not (image .mask .array ), size = image .shape [0 ]
536508 )
537509 slim_to_native_blurring = jnp .nonzero (
538- jnp .logical_not (self . blurring_mask ), size = blurring_array .shape [0 ]
510+ jnp .logical_not (blurring_image . mask . array ), size = blurring_image .shape [0 ]
539511 )
540512
541- expanded_array_native = jnp .zeros (self .mask .shape )
513+ expanded_array_native = jnp .zeros (image .mask .shape )
542514
543515 expanded_array_native = expanded_array_native .at [slim_to_native ].set (
544- array .array
516+ image .array
545517 )
546518 expanded_array_native = expanded_array_native .at [slim_to_native_blurring ].set (
547- blurring_array .array
519+ blurring_image .array
548520 )
549521
550- kernel = np .array (self .kernel . native .array )
522+ kernel = np .array (self .native .array )
551523
552524 convolve_native = jax .scipy .signal .convolve (
553- expanded_array_native , kernel , mode = "same" , method = method
525+ expanded_array_native , kernel , mode = "same" , method = jax_method
554526 )
555527
556- convolve_slim = convolve_native [slim_to_native ]
528+ convolved_array_1d = convolve_native [slim_to_native ]
557529
558- return convolve_slim
530+ return Array2D ( values = convolved_array_1d , mask = image . mask )
0 commit comments