Skip to content

Commit 6d4848e

Browse files
Jammy2211Jammy2211
authored andcommitted
fully tested bug fix
1 parent b2dd830 commit 6d4848e

3 files changed

Lines changed: 47 additions & 35 deletions

File tree

autoarray/dataset/imaging/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ def __init__(
189189
self.psf = psf
190190

191191
if psf is not None:
192-
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
193-
raise exc.KernelException("Kernel2D Kernel2D must be odd")
192+
if not psf.use_fft:
193+
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
194+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
194195

195196
self.grids = GridsDataset(
196197
mask=self.data.mask,

autoarray/mask/derive/mask_2d.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,6 @@ def blurring_from(self, kernel_shape_native: Tuple[int, int]) -> Mask2D:
198198

199199
from autoarray.mask.mask_2d import Mask2D
200200

201-
if kernel_shape_native[0] % 2 == 0 or kernel_shape_native[1] % 2 == 0:
202-
raise exc.MaskException("psf_size of exterior region must be odd")
203-
204201
blurring_mask = mask_2d_util.blurring_mask_2d_from(
205202
mask_2d=self.mask,
206203
kernel_shape_native=kernel_shape_native,

autoarray/structures/arrays/kernel_2d.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -723,14 +723,18 @@ def convolved_image_from(
723723
blurred_image_full = xp.fft.irfft2(
724724
fft_psf * fft_image_native, s=fft_shape, axes=(0, 1)
725725
)
726+
ky, kx = self.native.array.shape # (21, 21)
727+
off_y = (ky - 1) // 2
728+
off_x = (kx - 1) // 2
726729

727-
# Crop back to mask_shape
728-
start_indices = tuple(
729-
(full_size - out_size) // 2
730-
for full_size, out_size in zip(full_shape, mask_shape)
730+
blurred_image_full = xp.roll(
731+
blurred_image_full, shift=(-off_y, -off_x), axis=(0, 1)
731732
)
733+
734+
start_indices = (off_y, off_x)
735+
732736
blurred_image_native = jax.lax.dynamic_slice(
733-
blurred_image_full, start_indices, mask_shape
737+
blurred_image_full, start_indices, image.mask.shape
734738
)
735739

736740
# Return slim form; optionally cast for downstream stability
@@ -806,6 +810,10 @@ def convolved_mapping_matrix_from(
806810
ndarray of shape (N_pix, N_src)
807811
Convolved mapping matrix in slim form.
808812
"""
813+
# -------------------------------------------------------------------------
814+
# NumPy path unchanged
815+
# -------------------------------------------------------------------------
816+
809817
# -------------------------------------------------------------------------
810818
# NumPy path unchanged
811819
# -------------------------------------------------------------------------
@@ -835,34 +843,24 @@ def convolved_mapping_matrix_from(
835843
import jax.numpy as jnp
836844

837845
# -------------------------------------------------------------------------
838-
# Validate cached FFT shapes / state
846+
# Cached FFT shapes/state (REQUIRED)
839847
# -------------------------------------------------------------------------
840848
if self.fft_shape is None:
841-
full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask)
842849
raise ValueError(
843-
f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n"
844-
f"Expected mapping matrix padded to match FFT shape of PSF.\n"
845-
f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, "
846-
f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}."
850+
"FFT convolution requires precomputed FFT shapes on the PSF."
847851
)
848-
else:
849-
fft_shape = self.fft_shape
850-
full_shape = self.full_shape
851-
mask_shape = self.mask_shape
852-
fft_psf_mapping = self.fft_psf_mapping
852+
853+
fft_shape = self.fft_shape
854+
fft_psf_mapping = self.fft_psf_mapping
853855

854856
# -------------------------------------------------------------------------
855-
# Mixed precision dtypes (JAX only)
857+
# Mixed precision handling
856858
# -------------------------------------------------------------------------
857859
fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128
858-
859-
# Ensure PSF FFT dtype matches the FFT path
860860
fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype)
861861

862862
# -------------------------------------------------------------------------
863-
# Build native cube in the FFT dtype (THIS IS THE KEY)
864-
# This relies on mapping_matrix_native_from honoring the use_mixed_precision
865-
# kwarg when constructing the native mapping matrix.
863+
# Build native cube on the *native mask grid*
866864
# -------------------------------------------------------------------------
867865
mapping_matrix_native = self.mapping_matrix_native_from(
868866
mapping_matrix=mapping_matrix,
@@ -872,35 +870,51 @@ def convolved_mapping_matrix_from(
872870
use_mixed_precision=use_mixed_precision,
873871
xp=xp,
874872
)
873+
# shape: (ny_native, nx_native, n_src)
875874

876875
# -------------------------------------------------------------------------
877876
# FFT convolution
878877
# -------------------------------------------------------------------------
879878
fft_mapping_matrix_native = xp.fft.rfft2(
880879
mapping_matrix_native, s=fft_shape, axes=(0, 1)
881880
)
881+
882882
blurred_mapping_matrix_full = xp.fft.irfft2(
883883
fft_psf_mapping * fft_mapping_matrix_native,
884884
s=fft_shape,
885885
axes=(0, 1),
886886
)
887887

888888
# -------------------------------------------------------------------------
889-
# Crop back to mask-shape
889+
# APPLY SAME FIX AS convolved_image_from
890890
# -------------------------------------------------------------------------
891-
start_indices = tuple(
892-
(full_size - out_size) // 2
893-
for full_size, out_size in zip(full_shape, mask_shape)
894-
) + (0,)
895-
out_shape_full = mask_shape + (blurred_mapping_matrix_full.shape[2],)
891+
ky, kx = self.native.array.shape
892+
off_y = (ky - 1) // 2
893+
off_x = (kx - 1) // 2
894+
895+
blurred_mapping_matrix_full = xp.roll(
896+
blurred_mapping_matrix_full,
897+
shift=(-off_y, -off_x),
898+
axis=(0, 1),
899+
)
900+
901+
# -------------------------------------------------------------------------
902+
# Extract native grid (same as image path)
903+
# -------------------------------------------------------------------------
904+
native_shape = mask.shape
905+
start_indices = (off_y, off_x, 0)
906+
907+
out_shape = native_shape + (blurred_mapping_matrix_full.shape[2],)
896908

897909
blurred_mapping_matrix_native = jax.lax.dynamic_slice(
898910
blurred_mapping_matrix_full,
899911
start_indices,
900-
out_shape_full,
912+
out_shape,
901913
)
902914

903-
# Return slim form
915+
# -------------------------------------------------------------------------
916+
# Slim using ORIGINAL mask indices (same grid)
917+
# -------------------------------------------------------------------------
904918
blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple]
905919

906920
return blurred_slim

0 commit comments

Comments
 (0)