@@ -23,16 +23,36 @@ class NUFFTPlaceholder:
2323from autoarray .operators import transformer_util
2424
2525
26+ try :
27+ import nufftax as _nufftax
28+ except ModuleNotFoundError :
29+ _nufftax = None
30+
31+
2632def pynufft_exception ():
2733 raise ModuleNotFoundError (
2834 "\n --------------------\n "
29- "You are attempting to perform interferometer analysis.\n \n "
35+ "You are attempting to perform interferometer analysis with the legacy "
36+ "pynufft-backed `TransformerNUFFTPyNUFFT`.\n \n "
3037 "However, the optional library PyNUFFT (https://github.com/jyhmiinlin/pynufft) is not installed.\n \n "
3138 "Install it via the command `pip install pynufft==2022.2.2`.\n \n "
3239 "----------------------"
3340 )
3441
3542
43+ def nufftax_exception ():
44+ raise ModuleNotFoundError (
45+ "\n --------------------\n "
46+ "You are attempting to perform interferometer analysis with the default "
47+ "JAX-native `TransformerNUFFT`.\n \n "
48+ "However, the optional library nufftax (https://github.com/GragasLab/nufftax) is not installed.\n \n "
49+ "Install it via the command `pip install nufftax`.\n \n "
50+ "If you want to use the legacy pynufft backend instead, pass "
51+ "`transformer_class=TransformerNUFFTPyNUFFT` and install pynufft.\n \n "
52+ "----------------------"
53+ )
54+
55+
3656class TransformerDFT :
3757 def __init__ (
3858 self ,
@@ -175,13 +195,18 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
175195 )
176196
177197
178- class TransformerNUFFT (NUFFT_cpu ):
198+ class TransformerNUFFTPyNUFFT (NUFFT_cpu ):
179199 def __init__ (
180200 self , uv_wavelengths : np .ndarray , real_space_mask : Mask2D , xp = np , ** kwargs
181201 ):
182202 """
183203 Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction.
184204
205+ Legacy pynufft-backed transformer. The default `TransformerNUFFT` is now backed by `nufftax`
206+ (JAX-native, differentiable, ~zero gridding error) — this class is retained so users who depend
207+ on pynufft's specific gridding behaviour can opt in by passing
208+ `transformer_class=TransformerNUFFTPyNUFFT`.
209+
185210 This transformer uses the PyNUFFT library to efficiently compute the Fourier transform
186211 of an image defined on a regular real-space grid to a set of non-uniform uv-plane (Fourier space)
187212 coordinates, as is typical in radio interferometry.
@@ -226,7 +251,7 @@ def __init__(
226251 if isinstance (self , NUFFTPlaceholder ):
227252 pynufft_exception ()
228253
229- super (TransformerNUFFT , self ).__init__ ()
254+ super (TransformerNUFFTPyNUFFT , self ).__init__ ()
230255
231256 self .uv_wavelengths = uv_wavelengths
232257 self .real_space_mask = real_space_mask
@@ -469,3 +494,212 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
469494 )
470495
471496 return transformed_mapping_matrix
497+
498+
499+ class TransformerNUFFT :
500+ def __init__ (
501+ self ,
502+ uv_wavelengths : np .ndarray ,
503+ real_space_mask : Mask2D ,
504+ eps : float = 1e-12 ,
505+ xp = np ,
506+ ** kwargs ,
507+ ):
508+ """
509+ JAX-native Non-Uniform FFT for image -> visibilities, backed by `nufftax`.
510+
511+ This is the default `TransformerNUFFT` in PyAutoArray. It uses the
512+ `nufftax` library (https://github.com/GragasLab/nufftax), a pure-JAX
513+ NUFFT implementation that supports `jax.jit`, `jax.grad`, and
514+ `jax.vmap`. It replaces the legacy `TransformerNUFFTPyNUFFT` (which
515+ wraps the non-differentiable `pynufft` library) as the default backend.
516+
517+ Convention recipe (matches `TransformerDFT` to ~1e-13 relative across
518+ odd/even/non-square image sizes):
519+
520+ image_flipped = image[::-1, :]
521+ x = 2 * pi * u_lambda * pixel_scale_rad
522+ y = 2 * pi * v_lambda * pixel_scale_rad
523+ offset_x = 0.5 if N_x is even else 0.0
524+ offset_y = 0.5 if N_y is even else 0.0
525+ shift = exp(-i * (offset_x * x + offset_y * y))
526+ visibilities = nufftax.nufft2d2(x, y, image_flipped, eps, -1) * shift
527+
528+ The `shift` factor is the half-pixel correction between autoarray's
529+ grid centre at index `(N - 1) / 2` and nufftax's mode-0 at index
530+ `N // 2`; pynufft applies this internally, nufftax does not.
531+
532+ Parameters
533+ ----------
534+ uv_wavelengths
535+ The (u, v) coordinates of the measured visibilities in wavelengths,
536+ shape `(n_vis, 2)`.
537+ real_space_mask
538+ The 2D mask defining the real-space image grid.
539+ eps
540+ Requested NUFFT precision passed to nufftax. Defaults to `1e-12`
541+ (effectively machine precision); relax to `1e-9` or `1e-6` for
542+ faster execution if marginal accuracy is acceptable.
543+ xp
544+ Accepted for signature compatibility with the legacy class; not
545+ stored. The active backend is selected per-call via the `xp`
546+ argument to `visibilities_from` / `image_from`.
547+
548+ Attributes
549+ ----------
550+ grid
551+ The real-space pixel grid in radians (computed from the mask).
552+ total_visibilities
553+ Number of measured visibilities.
554+ total_image_pixels
555+ Number of unmasked pixels in the image grid.
556+ adjoint_scaling
557+ Scaling factor available for callers who want to apply an
558+ optional normalisation to the adjoint output. Provided for
559+ parity with the legacy class.
560+ """
561+ from astropy import units
562+
563+ if _nufftax is None :
564+ nufftax_exception ()
565+
566+ self .uv_wavelengths = uv_wavelengths .astype ("float" )
567+ self .real_space_mask = real_space_mask
568+ self .grid = Grid2D .from_mask (mask = self .real_space_mask ).in_radians
569+ self .eps = eps
570+ self .native_index_for_slim_index = copy .copy (
571+ real_space_mask .derive_indexes .native_for_slim .astype ("int" )
572+ )
573+
574+ pixel_scale_rad = self .grid .pixel_scales [0 ] * units .arcsec .to (units .rad )
575+ # nufft2d2 frequency arguments:
576+ # x is paired with the column-axis mode (image x)
577+ # y is paired with the row-axis mode (image y)
578+ # Both must lie in [-pi, pi); the 2*pi*Δ_rad scaling makes uv_lambda
579+ # land in that range for any sane uv-coverage.
580+ self ._x = 2.0 * np .pi * self .uv_wavelengths [:, 0 ] * pixel_scale_rad
581+ self ._y = 2.0 * np .pi * self .uv_wavelengths [:, 1 ] * pixel_scale_rad
582+
583+ n_y , n_x = self .real_space_mask .shape_native
584+ offset_x = 0.5 if n_x % 2 == 0 else 0.0
585+ offset_y = 0.5 if n_y % 2 == 0 else 0.0
586+ self ._shift = np .exp (- 1j * (offset_x * self ._x + offset_y * self ._y ))
587+
588+ self .total_visibilities = uv_wavelengths .shape [0 ]
589+ self .total_image_pixels = real_space_mask .pixels_in_mask
590+ self .adjoint_scaling = (2.0 * n_y ) * (2.0 * n_x )
591+
592+ def _forward_native (self , image_native_2d , xp = np ):
593+ """Run nufft2d2 on a 2D native-shape image array, returning visibilities."""
594+ if xp .__name__ .startswith ("jax" ):
595+ import jax .numpy as jnp
596+
597+ img = jnp .asarray (image_native_2d )[::- 1 , :].astype (jnp .complex128 )
598+ x = jnp .asarray (self ._x )
599+ y = jnp .asarray (self ._y )
600+ shift = jnp .asarray (self ._shift )
601+ return _nufftax .nufft2d2 (x , y , img , self .eps , - 1 ) * shift
602+
603+ img = image_native_2d [::- 1 , :].astype (np .complex128 )
604+ out = _nufftax .nufft2d2 (self ._x , self ._y , img , self .eps , - 1 ) * self ._shift
605+ return np .array (np .asarray (out ))
606+
607+ def visibilities_from (self , image , xp = np ) -> Visibilities :
608+ """
609+ Forward NUFFT: real-space image -> visibilities at the configured uv points.
610+
611+ For numpy callers (`xp=np`) the result is materialised back to numpy
612+ before being wrapped in `Visibilities`. For JAX callers (`xp=jnp`)
613+ the result stays as a `jax.Array` so it can flow through `jax.jit`
614+ / `jax.grad` / `jax.vmap` without device round-trips.
615+ """
616+ if xp .__name__ .startswith ("jax" ):
617+ import jax .numpy as jnp
618+
619+ image_native = jnp .zeros (image .mask .shape , dtype = image .dtype )
620+ image_native = image_native .at [image .mask .slim_to_native_tuple ].set (
621+ image .array
622+ )
623+ else :
624+ image_native = image .native .array
625+
626+ return Visibilities (visibilities = self ._forward_native (image_native , xp = xp ))
627+
628+ def image_from (
629+ self ,
630+ visibilities : Visibilities ,
631+ use_adjoint_scaling : bool = False ,
632+ xp = np ,
633+ ) -> Array2D :
634+ """
635+ Adjoint NUFFT: visibilities -> real-space (dirty) image.
636+
637+ Implemented as `nufftax.nufft2d1` with `conj(shift)` applied to the
638+ visibilities and a final row-flip to return to autoarray's native
639+ orientation. The real part is taken to discard imaginary residue,
640+ matching the legacy class' behaviour.
641+
642+ Note that this is the **mathematical adjoint** of `visibilities_from`,
643+ with no kernel deconvolution applied. The dirty image therefore
644+ differs in absolute scale from the legacy `TransformerNUFFTPyNUFFT`
645+ adjoint (which applies pynufft's internal IFFT and kernel
646+ deconvolution). The structure of the dirty image is the same, and
647+ the values match `TransformerDFT.image_from` exactly.
648+
649+ **Scale-sensitive callers**: `Interferometer.apply_sparse_operator`
650+ consumes the dirty-image scale together with a precision operator; it
651+ is currently incompatible with this class and raises
652+ `NotImplementedError`. Use `TransformerDFT` or
653+ `TransformerNUFFTPyNUFFT` if you need the sparse-operator path.
654+ """
655+ n_y , n_x = self .real_space_mask .shape_native
656+ n_modes = (n_x , n_y ) # nufftax wants (n1, n2) = (N_x, N_y)
657+
658+ if xp .__name__ .startswith ("jax" ):
659+ import jax .numpy as jnp
660+
661+ x = jnp .asarray (self ._x )
662+ y = jnp .asarray (self ._y )
663+ shift_conj = jnp .asarray (np .conj (self ._shift ))
664+ c = jnp .asarray (visibilities .array ) * shift_conj
665+ f = _nufftax .nufft2d1 (x , y , c , n_modes , self .eps , + 1 )
666+ image = jnp .real (f )[::- 1 , :]
667+ else :
668+ c = visibilities .array * np .conj (self ._shift )
669+ f = _nufftax .nufft2d1 (self ._x , self ._y , c , n_modes , self .eps , + 1 )
670+ image = np .array (np .asarray (f )[::- 1 , :].real )
671+
672+ if use_adjoint_scaling :
673+ image = image * self .adjoint_scaling
674+
675+ return Array2D (values = image , mask = self .real_space_mask )
676+
677+ def transform_mapping_matrix (self , mapping_matrix , xp = np ):
678+ """
679+ Apply the forward NUFFT to each column of a mapping matrix.
680+
681+ Each column is scattered back to the native 2D image grid using the
682+ mask's `slim_to_native_tuple`, then passed through `_forward_native`.
683+ """
684+ n_uv = self .uv_wavelengths .shape [0 ]
685+ n_src = mapping_matrix .shape [1 ]
686+ slim_to_native = self .real_space_mask .slim_to_native_tuple
687+ native_shape = self .real_space_mask .shape_native
688+
689+ if xp .__name__ .startswith ("jax" ):
690+ import jax .numpy as jnp
691+
692+ out = jnp .zeros ((n_uv , n_src ), dtype = jnp .complex128 )
693+ for k in range (n_src ):
694+ image_2d = jnp .zeros (native_shape , dtype = mapping_matrix .dtype )
695+ image_2d = image_2d .at [slim_to_native ].set (mapping_matrix [:, k ])
696+ vis = self ._forward_native (image_2d , xp = xp )
697+ out = out .at [:, k ].set (vis )
698+ return out
699+
700+ out = np .zeros ((n_uv , n_src ), dtype = np .complex128 )
701+ for k in range (n_src ):
702+ image_2d = np .zeros (native_shape , dtype = mapping_matrix .dtype )
703+ image_2d [slim_to_native ] = mapping_matrix [:, k ]
704+ out [:, k ] = self ._forward_native (image_2d , xp = xp )
705+ return out
0 commit comments