Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .mask.mask_2d import Mask2D
from .operators.transformer import TransformerDFT
from .operators.transformer import TransformerNUFFT
from .operators.transformer import TransformerNUFFTPyNUFFT
from .operators.over_sampling.decorator import over_sample
from .operators.contour import Grid2DContour
from .layout.layout import Layout1D
Expand Down
23 changes: 23 additions & 0 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,29 @@ def apply_sparse_operator(
enabling efficient pixelized source reconstruction via the sparse linear algebra formalism.
"""

if isinstance(self.transformer, TransformerNUFFT):
raise NotImplementedError(
"\n--------------------\n"
"`apply_sparse_operator` is not yet supported with the default "
"`TransformerNUFFT` (nufftax-backed) transformer.\n\n"
"The sparse-operator path consumes the dirty image returned by "
"`transformer.image_from(use_adjoint_scaling=True)` together with "
"the NUFFT precision operator; their relative scale matters. The "
"new `TransformerNUFFT` returns the strict mathematical adjoint "
"(matching `TransformerDFT`), whereas the legacy pynufft adjoint "
"applies an internal Kaiser-Bessel kernel deconvolution. The two "
"scales differ by a non-constant factor, so feeding the new "
"dirty image into the existing sparse-operator solver would "
"silently give wrong answers.\n\n"
"Workarounds:\n"
" - Build the dataset with `transformer_class=TransformerDFT` "
"(the JAX-likelihood scripts do this today), or\n"
" - Build the dataset with "
"`transformer_class=TransformerNUFFTPyNUFFT` to keep the legacy "
"pynufft adjoint scale (requires `pip install pynufft`).\n"
"----------------------"
)

if nufft_precision_operator is None:

logger.info(
Expand Down
240 changes: 237 additions & 3 deletions autoarray/operators/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,36 @@ class NUFFTPlaceholder:
from autoarray.operators import transformer_util


try:
import nufftax as _nufftax
except ModuleNotFoundError:
_nufftax = None


def pynufft_exception():
raise ModuleNotFoundError(
"\n--------------------\n"
"You are attempting to perform interferometer analysis.\n\n"
"You are attempting to perform interferometer analysis with the legacy "
"pynufft-backed `TransformerNUFFTPyNUFFT`.\n\n"
"However, the optional library PyNUFFT (https://github.com/jyhmiinlin/pynufft) is not installed.\n\n"
"Install it via the command `pip install pynufft==2022.2.2`.\n\n"
"----------------------"
)


def nufftax_exception():
raise ModuleNotFoundError(
"\n--------------------\n"
"You are attempting to perform interferometer analysis with the default "
"JAX-native `TransformerNUFFT`.\n\n"
"However, the optional library nufftax (https://github.com/GragasLab/nufftax) is not installed.\n\n"
"Install it via the command `pip install nufftax`.\n\n"
"If you want to use the legacy pynufft backend instead, pass "
"`transformer_class=TransformerNUFFTPyNUFFT` and install pynufft.\n\n"
"----------------------"
)


class TransformerDFT:
def __init__(
self,
Expand Down Expand Up @@ -175,13 +195,18 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
)


class TransformerNUFFT(NUFFT_cpu):
class TransformerNUFFTPyNUFFT(NUFFT_cpu):
def __init__(
self, uv_wavelengths: np.ndarray, real_space_mask: Mask2D, xp=np, **kwargs
):
"""
Performs the Non-Uniform Fast Fourier Transform (NUFFT) for interferometric image reconstruction.

Legacy pynufft-backed transformer. The default `TransformerNUFFT` is now backed by `nufftax`
(JAX-native, differentiable, ~zero gridding error) — this class is retained so users who depend
on pynufft's specific gridding behaviour can opt in by passing
`transformer_class=TransformerNUFFTPyNUFFT`.

This transformer uses the PyNUFFT library to efficiently compute the Fourier transform
of an image defined on a regular real-space grid to a set of non-uniform uv-plane (Fourier space)
coordinates, as is typical in radio interferometry.
Expand Down Expand Up @@ -226,7 +251,7 @@ def __init__(
if isinstance(self, NUFFTPlaceholder):
pynufft_exception()

super(TransformerNUFFT, self).__init__()
super(TransformerNUFFTPyNUFFT, self).__init__()

self.uv_wavelengths = uv_wavelengths
self.real_space_mask = real_space_mask
Expand Down Expand Up @@ -469,3 +494,212 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
)

return transformed_mapping_matrix


class TransformerNUFFT:
def __init__(
self,
uv_wavelengths: np.ndarray,
real_space_mask: Mask2D,
eps: float = 1e-12,
xp=np,
**kwargs,
):
"""
JAX-native Non-Uniform FFT for image -> visibilities, backed by `nufftax`.

This is the default `TransformerNUFFT` in PyAutoArray. It uses the
`nufftax` library (https://github.com/GragasLab/nufftax), a pure-JAX
NUFFT implementation that supports `jax.jit`, `jax.grad`, and
`jax.vmap`. It replaces the legacy `TransformerNUFFTPyNUFFT` (which
wraps the non-differentiable `pynufft` library) as the default backend.

Convention recipe (matches `TransformerDFT` to ~1e-13 relative across
odd/even/non-square image sizes):

image_flipped = image[::-1, :]
x = 2 * pi * u_lambda * pixel_scale_rad
y = 2 * pi * v_lambda * pixel_scale_rad
offset_x = 0.5 if N_x is even else 0.0
offset_y = 0.5 if N_y is even else 0.0
shift = exp(-i * (offset_x * x + offset_y * y))
visibilities = nufftax.nufft2d2(x, y, image_flipped, eps, -1) * shift

The `shift` factor is the half-pixel correction between autoarray's
grid centre at index `(N - 1) / 2` and nufftax's mode-0 at index
`N // 2`; pynufft applies this internally, nufftax does not.

Parameters
----------
uv_wavelengths
The (u, v) coordinates of the measured visibilities in wavelengths,
shape `(n_vis, 2)`.
real_space_mask
The 2D mask defining the real-space image grid.
eps
Requested NUFFT precision passed to nufftax. Defaults to `1e-12`
(effectively machine precision); relax to `1e-9` or `1e-6` for
faster execution if marginal accuracy is acceptable.
xp
Accepted for signature compatibility with the legacy class; not
stored. The active backend is selected per-call via the `xp`
argument to `visibilities_from` / `image_from`.

Attributes
----------
grid
The real-space pixel grid in radians (computed from the mask).
total_visibilities
Number of measured visibilities.
total_image_pixels
Number of unmasked pixels in the image grid.
adjoint_scaling
Scaling factor available for callers who want to apply an
optional normalisation to the adjoint output. Provided for
parity with the legacy class.
"""
from astropy import units

if _nufftax is None:
nufftax_exception()

self.uv_wavelengths = uv_wavelengths.astype("float")
self.real_space_mask = real_space_mask
self.grid = Grid2D.from_mask(mask=self.real_space_mask).in_radians
self.eps = eps
self.native_index_for_slim_index = copy.copy(
real_space_mask.derive_indexes.native_for_slim.astype("int")
)

pixel_scale_rad = self.grid.pixel_scales[0] * units.arcsec.to(units.rad)
# nufft2d2 frequency arguments:
# x is paired with the column-axis mode (image x)
# y is paired with the row-axis mode (image y)
# Both must lie in [-pi, pi); the 2*pi*Δ_rad scaling makes uv_lambda
# land in that range for any sane uv-coverage.
self._x = 2.0 * np.pi * self.uv_wavelengths[:, 0] * pixel_scale_rad
self._y = 2.0 * np.pi * self.uv_wavelengths[:, 1] * pixel_scale_rad

n_y, n_x = self.real_space_mask.shape_native
offset_x = 0.5 if n_x % 2 == 0 else 0.0
offset_y = 0.5 if n_y % 2 == 0 else 0.0
self._shift = np.exp(-1j * (offset_x * self._x + offset_y * self._y))

self.total_visibilities = uv_wavelengths.shape[0]
self.total_image_pixels = real_space_mask.pixels_in_mask
self.adjoint_scaling = (2.0 * n_y) * (2.0 * n_x)

def _forward_native(self, image_native_2d, xp=np):
"""Run nufft2d2 on a 2D native-shape image array, returning visibilities."""
if xp.__name__.startswith("jax"):
import jax.numpy as jnp

img = jnp.asarray(image_native_2d)[::-1, :].astype(jnp.complex128)
x = jnp.asarray(self._x)
y = jnp.asarray(self._y)
shift = jnp.asarray(self._shift)
return _nufftax.nufft2d2(x, y, img, self.eps, -1) * shift

img = image_native_2d[::-1, :].astype(np.complex128)
out = _nufftax.nufft2d2(self._x, self._y, img, self.eps, -1) * self._shift
return np.array(np.asarray(out))

def visibilities_from(self, image, xp=np) -> Visibilities:
"""
Forward NUFFT: real-space image -> visibilities at the configured uv points.

For numpy callers (`xp=np`) the result is materialised back to numpy
before being wrapped in `Visibilities`. For JAX callers (`xp=jnp`)
the result stays as a `jax.Array` so it can flow through `jax.jit`
/ `jax.grad` / `jax.vmap` without device round-trips.
"""
if xp.__name__.startswith("jax"):
import jax.numpy as jnp

image_native = jnp.zeros(image.mask.shape, dtype=image.dtype)
image_native = image_native.at[image.mask.slim_to_native_tuple].set(
image.array
)
else:
image_native = image.native.array

return Visibilities(visibilities=self._forward_native(image_native, xp=xp))

def image_from(
self,
visibilities: Visibilities,
use_adjoint_scaling: bool = False,
xp=np,
) -> Array2D:
"""
Adjoint NUFFT: visibilities -> real-space (dirty) image.

Implemented as `nufftax.nufft2d1` with `conj(shift)` applied to the
visibilities and a final row-flip to return to autoarray's native
orientation. The real part is taken to discard imaginary residue,
matching the legacy class' behaviour.

Note that this is the **mathematical adjoint** of `visibilities_from`,
with no kernel deconvolution applied. The dirty image therefore
differs in absolute scale from the legacy `TransformerNUFFTPyNUFFT`
adjoint (which applies pynufft's internal IFFT and kernel
deconvolution). The structure of the dirty image is the same, and
the values match `TransformerDFT.image_from` exactly.

**Scale-sensitive callers**: `Interferometer.apply_sparse_operator`
consumes the dirty-image scale together with a precision operator; it
is currently incompatible with this class and raises
`NotImplementedError`. Use `TransformerDFT` or
`TransformerNUFFTPyNUFFT` if you need the sparse-operator path.
"""
n_y, n_x = self.real_space_mask.shape_native
n_modes = (n_x, n_y) # nufftax wants (n1, n2) = (N_x, N_y)

if xp.__name__.startswith("jax"):
import jax.numpy as jnp

x = jnp.asarray(self._x)
y = jnp.asarray(self._y)
shift_conj = jnp.asarray(np.conj(self._shift))
c = jnp.asarray(visibilities.array) * shift_conj
f = _nufftax.nufft2d1(x, y, c, n_modes, self.eps, +1)
image = jnp.real(f)[::-1, :]
else:
c = visibilities.array * np.conj(self._shift)
f = _nufftax.nufft2d1(self._x, self._y, c, n_modes, self.eps, +1)
image = np.array(np.asarray(f)[::-1, :].real)

if use_adjoint_scaling:
image = image * self.adjoint_scaling

return Array2D(values=image, mask=self.real_space_mask)

def transform_mapping_matrix(self, mapping_matrix, xp=np):
"""
Apply the forward NUFFT to each column of a mapping matrix.

Each column is scattered back to the native 2D image grid using the
mask's `slim_to_native_tuple`, then passed through `_forward_native`.
"""
n_uv = self.uv_wavelengths.shape[0]
n_src = mapping_matrix.shape[1]
slim_to_native = self.real_space_mask.slim_to_native_tuple
native_shape = self.real_space_mask.shape_native

if xp.__name__.startswith("jax"):
import jax.numpy as jnp

out = jnp.zeros((n_uv, n_src), dtype=jnp.complex128)
for k in range(n_src):
image_2d = jnp.zeros(native_shape, dtype=mapping_matrix.dtype)
image_2d = image_2d.at[slim_to_native].set(mapping_matrix[:, k])
vis = self._forward_native(image_2d, xp=xp)
out = out.at[:, k].set(vis)
return out

out = np.zeros((n_uv, n_src), dtype=np.complex128)
for k in range(n_src):
image_2d = np.zeros(native_shape, dtype=mapping_matrix.dtype)
image_2d[slim_to_native] = mapping_matrix[:, k]
out[:, k] = self._forward_native(image_2d, xp=xp)
return out
3 changes: 2 additions & 1 deletion autoarray/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@

from autoarray.operators.transformer import TransformerDFT
from autoarray.operators.transformer import TransformerNUFFT
from autoarray.operators.transformer import TransformerNUFFTPyNUFFT

Transformer = Union[TransformerDFT, TransformerNUFFT]
Transformer = Union[TransformerDFT, TransformerNUFFT, TransformerNUFFTPyNUFFT]


from autoarray.layout.region import Region1D
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ jax = ["autoconf[jax]"]
optional = [
"autoarray[jax]",
"numba",
"nufftax",
"pynufft",
"tensorflow-probability==0.25.0"
]
test = ["pytest"]
dev = ["pytest", "black", "numba", "pynufft==2022.2.2"]
dev = ["pytest", "black", "numba", "nufftax", "pynufft==2022.2.2"]

[tool.pytest.ini_options]
testpaths = ["test_autoarray"]
Expand Down
Loading
Loading