@@ -723,14 +723,18 @@ def convolved_image_from(
723723 blurred_image_full = xp .fft .irfft2 (
724724 fft_psf * fft_image_native , s = fft_shape , axes = (0 , 1 )
725725 )
726+ ky , kx = self .native .array .shape # (21, 21)
727+ off_y = (ky - 1 ) // 2
728+ off_x = (kx - 1 ) // 2
726729
727- # Crop back to mask_shape
728- start_indices = tuple (
729- (full_size - out_size ) // 2
730- for full_size , out_size in zip (full_shape , mask_shape )
730+ blurred_image_full = xp .roll (
731+ blurred_image_full , shift = (- off_y , - off_x ), axis = (0 , 1 )
731732 )
733+
734+ start_indices = (off_y , off_x )
735+
732736 blurred_image_native = jax .lax .dynamic_slice (
733- blurred_image_full , start_indices , mask_shape
737+ blurred_image_full , start_indices , image . mask . shape
734738 )
735739
736740 # Return slim form; optionally cast for downstream stability
@@ -806,6 +810,10 @@ def convolved_mapping_matrix_from(
806810 ndarray of shape (N_pix, N_src)
807811 Convolved mapping matrix in slim form.
808812 """
813+ # -------------------------------------------------------------------------
814+ # NumPy path unchanged
815+ # -------------------------------------------------------------------------
816+
809817 # -------------------------------------------------------------------------
810818 # NumPy path unchanged
811819 # -------------------------------------------------------------------------
@@ -835,34 +843,24 @@ def convolved_mapping_matrix_from(
835843 import jax .numpy as jnp
836844
837845 # -------------------------------------------------------------------------
838- # Validate cached FFT shapes / state
846+ # Cached FFT shapes/ state (REQUIRED)
839847 # -------------------------------------------------------------------------
840848 if self .fft_shape is None :
841- full_shape , fft_shape , mask_shape = self .fft_shape_from (mask = mask )
842849 raise ValueError (
843- f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n "
844- f"Expected mapping matrix padded to match FFT shape of PSF.\n "
845- f"PSF fft_shape: { fft_shape } , mask shape: { mask .shape } , "
846- f"mapping_matrix shape: { getattr (mapping_matrix , 'shape' , 'unknown' )} ."
850+ "FFT convolution requires precomputed FFT shapes on the PSF."
847851 )
848- else :
849- fft_shape = self .fft_shape
850- full_shape = self .full_shape
851- mask_shape = self .mask_shape
852- fft_psf_mapping = self .fft_psf_mapping
852+
853+ fft_shape = self .fft_shape
854+ fft_psf_mapping = self .fft_psf_mapping
853855
854856 # -------------------------------------------------------------------------
855- # Mixed precision dtypes (JAX only)
857+ # Mixed precision handling
856858 # -------------------------------------------------------------------------
857859 fft_complex_dtype = jnp .complex64 if use_mixed_precision else jnp .complex128
858-
859- # Ensure PSF FFT dtype matches the FFT path
860860 fft_psf_mapping = jnp .asarray (fft_psf_mapping , dtype = fft_complex_dtype )
861861
862862 # -------------------------------------------------------------------------
863- # Build native cube in the FFT dtype (THIS IS THE KEY)
864- # This relies on mapping_matrix_native_from honoring the use_mixed_precision
865- # kwarg when constructing the native mapping matrix.
863+ # Build native cube on the *native mask grid*
866864 # -------------------------------------------------------------------------
867865 mapping_matrix_native = self .mapping_matrix_native_from (
868866 mapping_matrix = mapping_matrix ,
@@ -872,35 +870,51 @@ def convolved_mapping_matrix_from(
872870 use_mixed_precision = use_mixed_precision ,
873871 xp = xp ,
874872 )
873+ # shape: (ny_native, nx_native, n_src)
875874
876875 # -------------------------------------------------------------------------
877876 # FFT convolution
878877 # -------------------------------------------------------------------------
879878 fft_mapping_matrix_native = xp .fft .rfft2 (
880879 mapping_matrix_native , s = fft_shape , axes = (0 , 1 )
881880 )
881+
882882 blurred_mapping_matrix_full = xp .fft .irfft2 (
883883 fft_psf_mapping * fft_mapping_matrix_native ,
884884 s = fft_shape ,
885885 axes = (0 , 1 ),
886886 )
887887
888888 # -------------------------------------------------------------------------
889- # Crop back to mask-shape
889+ # APPLY SAME FIX AS convolved_image_from
890890 # -------------------------------------------------------------------------
891- start_indices = tuple (
892- (full_size - out_size ) // 2
893- for full_size , out_size in zip (full_shape , mask_shape )
894- ) + (0 ,)
895- out_shape_full = mask_shape + (blurred_mapping_matrix_full .shape [2 ],)
891+ ky , kx = self .native .array .shape
892+ off_y = (ky - 1 ) // 2
893+ off_x = (kx - 1 ) // 2
894+
895+ blurred_mapping_matrix_full = xp .roll (
896+ blurred_mapping_matrix_full ,
897+ shift = (- off_y , - off_x ),
898+ axis = (0 , 1 ),
899+ )
900+
901+ # -------------------------------------------------------------------------
902+ # Extract native grid (same as image path)
903+ # -------------------------------------------------------------------------
904+ native_shape = mask .shape
905+ start_indices = (off_y , off_x , 0 )
906+
907+ out_shape = native_shape + (blurred_mapping_matrix_full .shape [2 ],)
896908
897909 blurred_mapping_matrix_native = jax .lax .dynamic_slice (
898910 blurred_mapping_matrix_full ,
899911 start_indices ,
900- out_shape_full ,
912+ out_shape ,
901913 )
902914
903- # Return slim form
915+ # -------------------------------------------------------------------------
916+ # Slim using ORIGINAL mask indices (same grid)
917+ # -------------------------------------------------------------------------
904918 blurred_slim = blurred_mapping_matrix_native [mask .slim_to_native_tuple ]
905919
906920 return blurred_slim
0 commit comments