Skip to content

Commit bc00c11

Browse files
authored
Merge pull request #305 from PyAutoLabs/feature/nufftax-batched-mapping-matrix
perf: batched transform_mapping_matrix in TransformerNUFFT (single nufft2d2 call)
2 parents 46bb880 + cab85c9 commit bc00c11

1 file changed

Lines changed: 34 additions & 19 deletions

File tree

autoarray/operators/transformer.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -678,28 +678,43 @@ def transform_mapping_matrix(self, mapping_matrix, xp=np):
678678
"""
679679
Apply the forward NUFFT to each column of a mapping matrix.
680680
681-
Each column is scattered back to the native 2D image grid using the
682-
mask's `slim_to_native_tuple`, then passed through `_forward_native`.
681+
All columns are scattered into a single batched native-shape image
682+
of shape ``(n_src, N_y, N_x)`` and passed through nufft2d2 in one
683+
call (nufft2d2 supports batched ``f``). This avoids the
684+
per-column Python loop that, under ``jax.jit``, would unroll into
685+
``n_src`` separate NUFFT invocations and blow up the JIT graph
686+
for pixelization-heavy fits (notably double-source-plane).
683687
"""
684-
n_uv = self.uv_wavelengths.shape[0]
685688
n_src = mapping_matrix.shape[1]
686-
slim_to_native = self.real_space_mask.slim_to_native_tuple
687-
native_shape = self.real_space_mask.shape_native
689+
rows, cols = self.real_space_mask.slim_to_native_tuple
690+
n_y, n_x = self.real_space_mask.shape_native
688691

689692
if xp.__name__.startswith("jax"):
690693
import jax.numpy as jnp
691694

692-
out = jnp.zeros((n_uv, n_src), dtype=jnp.complex128)
693-
for k in range(n_src):
694-
image_2d = jnp.zeros(native_shape, dtype=mapping_matrix.dtype)
695-
image_2d = image_2d.at[slim_to_native].set(mapping_matrix[:, k])
696-
vis = self._forward_native(image_2d, xp=xp)
697-
out = out.at[:, k].set(vis)
698-
return out
699-
700-
out = np.zeros((n_uv, n_src), dtype=np.complex128)
701-
for k in range(n_src):
702-
image_2d = np.zeros(native_shape, dtype=mapping_matrix.dtype)
703-
image_2d[slim_to_native] = mapping_matrix[:, k]
704-
out[:, k] = self._forward_native(image_2d, xp=xp)
705-
return out
695+
mm_T = jnp.asarray(mapping_matrix).T.astype(jnp.complex128)
696+
source_images = jnp.zeros((n_src, n_y, n_x), dtype=jnp.complex128)
697+
source_images = source_images.at[
698+
jnp.arange(n_src)[:, None],
699+
jnp.asarray(rows)[None, :],
700+
jnp.asarray(cols)[None, :],
701+
].set(mm_T)
702+
flipped = source_images[:, ::-1, :]
703+
x = jnp.asarray(self._x)
704+
y = jnp.asarray(self._y)
705+
shift = jnp.asarray(self._shift)
706+
# nufft2d2 returns shape (n_trans, M); transpose to (M, n_src).
707+
vis_batched = (
708+
_nufftax.nufft2d2(x, y, flipped, self.eps, -1) * shift[None, :]
709+
)
710+
return vis_batched.T
711+
712+
mm_T = np.asarray(mapping_matrix).T.astype(np.complex128)
713+
source_images = np.zeros((n_src, n_y, n_x), dtype=np.complex128)
714+
source_images[np.arange(n_src)[:, None], rows[None, :], cols[None, :]] = mm_T
715+
flipped = source_images[:, ::-1, :]
716+
vis_batched = (
717+
_nufftax.nufft2d2(self._x, self._y, flipped, self.eps, -1)
718+
* self._shift[None, :]
719+
)
720+
return np.array(np.asarray(vis_batched).T)

0 commit comments

Comments
 (0)