Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions autoarray/operators/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,28 +678,43 @@ def transform_mapping_matrix(self, mapping_matrix, xp=np):
"""
Apply the forward NUFFT to each column of a mapping matrix.

Each column is scattered back to the native 2D image grid using the
mask's `slim_to_native_tuple`, then passed through `_forward_native`.
All columns are scattered into a single batched native-shape image
of shape ``(n_src, N_y, N_x)`` and passed through nufft2d2 in one
call (nufft2d2 supports batched ``f``). This avoids the
per-column Python loop that, under ``jax.jit``, would unroll into
``n_src`` separate NUFFT invocations and blow up the JIT graph
for pixelization-heavy fits (notably double-source-plane).
"""
n_uv = self.uv_wavelengths.shape[0]
n_src = mapping_matrix.shape[1]
slim_to_native = self.real_space_mask.slim_to_native_tuple
native_shape = self.real_space_mask.shape_native
rows, cols = self.real_space_mask.slim_to_native_tuple
n_y, n_x = self.real_space_mask.shape_native

if xp.__name__.startswith("jax"):
import jax.numpy as jnp

out = jnp.zeros((n_uv, n_src), dtype=jnp.complex128)
for k in range(n_src):
image_2d = jnp.zeros(native_shape, dtype=mapping_matrix.dtype)
image_2d = image_2d.at[slim_to_native].set(mapping_matrix[:, k])
vis = self._forward_native(image_2d, xp=xp)
out = out.at[:, k].set(vis)
return out

out = np.zeros((n_uv, n_src), dtype=np.complex128)
for k in range(n_src):
image_2d = np.zeros(native_shape, dtype=mapping_matrix.dtype)
image_2d[slim_to_native] = mapping_matrix[:, k]
out[:, k] = self._forward_native(image_2d, xp=xp)
return out
mm_T = jnp.asarray(mapping_matrix).T.astype(jnp.complex128)
source_images = jnp.zeros((n_src, n_y, n_x), dtype=jnp.complex128)
source_images = source_images.at[
jnp.arange(n_src)[:, None],
jnp.asarray(rows)[None, :],
jnp.asarray(cols)[None, :],
].set(mm_T)
flipped = source_images[:, ::-1, :]
x = jnp.asarray(self._x)
y = jnp.asarray(self._y)
shift = jnp.asarray(self._shift)
# nufft2d2 returns shape (n_trans, M); transpose to (M, n_src).
vis_batched = (
_nufftax.nufft2d2(x, y, flipped, self.eps, -1) * shift[None, :]
)
return vis_batched.T

mm_T = np.asarray(mapping_matrix).T.astype(np.complex128)
source_images = np.zeros((n_src, n_y, n_x), dtype=np.complex128)
source_images[np.arange(n_src)[:, None], rows[None, :], cols[None, :]] = mm_T
flipped = source_images[:, ::-1, :]
vis_batched = (
_nufftax.nufft2d2(self._x, self._y, flipped, self.eps, -1)
* self._shift[None, :]
)
return np.array(np.asarray(vis_batched).T)
Loading