diff --git a/jax_profiling/jit/datacube/delaunay.py b/jax_profiling/jit/datacube/delaunay.py index 3d7918a..01ff97d 100644 --- a/jax_profiling/jit/datacube/delaunay.py +++ b/jax_profiling/jit/datacube/delaunay.py @@ -211,6 +211,11 @@ def jit_profile(func, label, *args, n_repeats=10): ) with timer.section("dataset_list_load"): + # apply_sparse_operator: precompute the NUFFT precision-matrix preload per + # channel so per-fit curvature assembly uses the FFT-based sparse path + # instead of dense DFT for every source pixel. Unblocked by PyAutoArray#316 + # (the Pmax > 1 extent-indexing fix); on Delaunay this was previously + # guarded with NotImplementedError. dataset_list = [ al.Interferometer.from_fits( data_path=dataset_path / "data.fits", @@ -222,7 +227,7 @@ def jit_profile(func, label, *args, n_repeats=10): # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet # JIT-friendly. raise_error_dft_visibilities_limit=False, - ) + ).apply_sparse_operator(use_jax=True, show_progress=False) for _ in range(n_channels) ] diff --git a/jax_profiling/jit/interferometer/delaunay.py b/jax_profiling/jit/interferometer/delaunay.py index 34c1d9b..4fec68c 100644 --- a/jax_profiling/jit/interferometer/delaunay.py +++ b/jax_profiling/jit/interferometer/delaunay.py @@ -212,6 +212,13 @@ def jit_profile(func, label, *args, n_repeats=10): raise_error_dft_visibilities_limit=False, ) +with timer.section("apply_sparse_operator"): + # Precompute the NUFFT precision-matrix preload so per-fit curvature + # assembly uses the FFT-based sparse path instead of dense DFT for every + # source pixel. Unblocked by PyAutoArray#316 (the Pmax > 1 extent-indexing + # fix); on Delaunay this was previously guarded with NotImplementedError. + dataset = dataset.apply_sparse_operator(use_jax=True, show_progress=True) + n_visibilities = dataset.uv_wavelengths.shape[0] print(f" Total visibilities: {n_visibilities}")