@@ -678,28 +678,43 @@ def transform_mapping_matrix(self, mapping_matrix, xp=np):
678678 """
679679 Apply the forward NUFFT to each column of a mapping matrix.
680680
681- Each column is scattered back to the native 2D image grid using the
682- mask's `slim_to_native_tuple`, then passed through `_forward_native`.
681+ All columns are scattered into a single batched native-shape image
682+ of shape ``(n_src, N_y, N_x)`` and passed through nufft2d2 in one
683+ call (nufft2d2 supports batched ``f``). This avoids the
684+ per-column Python loop that, under ``jax.jit``, would unroll into
685+ ``n_src`` separate NUFFT invocations and blow up the JIT graph
686+ for pixelization-heavy fits (notably double-source-plane).
683687 """
684- n_uv = self .uv_wavelengths .shape [0 ]
685688 n_src = mapping_matrix .shape [1 ]
686- slim_to_native = self .real_space_mask .slim_to_native_tuple
687- native_shape = self .real_space_mask .shape_native
689+ rows , cols = self .real_space_mask .slim_to_native_tuple
690+ n_y , n_x = self .real_space_mask .shape_native
688691
689692 if xp .__name__ .startswith ("jax" ):
690693 import jax .numpy as jnp
691694
692- out = jnp .zeros ((n_uv , n_src ), dtype = jnp .complex128 )
693- for k in range (n_src ):
694- image_2d = jnp .zeros (native_shape , dtype = mapping_matrix .dtype )
695- image_2d = image_2d .at [slim_to_native ].set (mapping_matrix [:, k ])
696- vis = self ._forward_native (image_2d , xp = xp )
697- out = out .at [:, k ].set (vis )
698- return out
699-
700- out = np .zeros ((n_uv , n_src ), dtype = np .complex128 )
701- for k in range (n_src ):
702- image_2d = np .zeros (native_shape , dtype = mapping_matrix .dtype )
703- image_2d [slim_to_native ] = mapping_matrix [:, k ]
704- out [:, k ] = self ._forward_native (image_2d , xp = xp )
705- return out
695+ mm_T = jnp .asarray (mapping_matrix ).T .astype (jnp .complex128 )
696+ source_images = jnp .zeros ((n_src , n_y , n_x ), dtype = jnp .complex128 )
697+ source_images = source_images .at [
698+ jnp .arange (n_src )[:, None ],
699+ jnp .asarray (rows )[None , :],
700+ jnp .asarray (cols )[None , :],
701+ ].set (mm_T )
702+ flipped = source_images [:, ::- 1 , :]
703+ x = jnp .asarray (self ._x )
704+ y = jnp .asarray (self ._y )
705+ shift = jnp .asarray (self ._shift )
706+ # nufft2d2 returns shape (n_trans, M); transpose to (M, n_src).
707+ vis_batched = (
708+ _nufftax .nufft2d2 (x , y , flipped , self .eps , - 1 ) * shift [None , :]
709+ )
710+ return vis_batched .T
711+
712+ mm_T = np .asarray (mapping_matrix ).T .astype (np .complex128 )
713+ source_images = np .zeros ((n_src , n_y , n_x ), dtype = np .complex128 )
714+ source_images [np .arange (n_src )[:, None ], rows [None , :], cols [None , :]] = mm_T
715+ flipped = source_images [:, ::- 1 , :]
716+ vis_batched = (
717+ _nufftax .nufft2d2 (self ._x , self ._y , flipped , self .eps , - 1 )
718+ * self ._shift [None , :]
719+ )
720+ return np .array (np .asarray (vis_batched ).T )
0 commit comments