diff --git a/likelihood_breakdown/datacube/delaunay.py b/likelihood_breakdown/datacube/delaunay.py index 3321d4e..6fa7371 100644 --- a/likelihood_breakdown/datacube/delaunay.py +++ b/likelihood_breakdown/datacube/delaunay.py @@ -41,10 +41,27 @@ Dataset ------- -This profiler reuses the SMA interferometer dataset -(``dataset/interferometer/sma/``) loaded N times as a 4-channel -"cube". Each channel has identical visibilities, noise map and uv_wavelengths -— the point here is timing, not science. +Uses the per-instrument dataset from the INSTRUMENTS dict (default: SMA, +override via ``--instrument alma`` etc.). The dataset is loaded N times as an +N-channel "cube" — each channel has identical visibilities, noise map and +uv_wavelengths. The point here is timing, not science. + +Every instrument is loaded with ``al.TransformerNUFFT``; the optional +``transformer_chunk_size`` entry in INSTRUMENTS is passed as its ``chunk_size``. +Chunking (PyAutoArray#330) splits the visibility axis via ``jax.lax.scan`` to +cap the nufftax gather buffer at ~3 GB per chunk. + +Dense vs sparse breakdown +------------------------- + +The dense ``transformed_mapping_matrix`` (visibilities × source pixels, +complex128) is only a few MB at SMA scale, but at ALMA scale (1M+ +visibilities) it reaches tens of GB and cannot be materialised. Production +never builds it, so the script always sets ``dense_breakdown_feasible = False`` +and profiles the **sparse-operator path** step by step: Steps 3–8 are +reported individually using the operator precomputed by +``apply_sparse_operator`` (the FFT-based sparse precision matrix). The dense +per-step code is kept behind that flag but does not run. 
Measures -------- @@ -102,12 +119,12 @@ from simulators.interferometer import INSTRUMENTS # noqa: E402 _cli = parse_profile_cli() -instrument = "sma" # <-- change to profile a different instrument; cube is N copies of the per-instrument dataset +instrument = _cli.instrument or "sma" # n_channels = 34 matches the prior Hannah ALMA cube fiducial. For quick # iteration on the smaller sma dataset, drop this to 4. n_channels = 34 -hilbert_pixels = 500 # 500-tier production fiducial per channel (× n_channels) +hilbert_pixels = 1500 # 1500-tier production fiducial (matches imaging/interferometer) regularization_coefficient = 1.0 @@ -168,7 +185,7 @@ def jit_profile(func, label, *args, n_repeats=10): # =================================================================== # --------------------------------------------------------------------------- -# 1. Dataset loading: reuse SMA interferometer dataset N times +# 1. Dataset loading: reuse per-instrument interferometer dataset N times # --------------------------------------------------------------------------- print(f"\n--- Dataset loading [{instrument}, {n_channels} channels] ---") @@ -194,24 +211,28 @@ def jit_profile(func, label, *args, n_repeats=10): radius=mask_radius, ) +transformer_chunk_size = INSTRUMENTS[instrument].get("transformer_chunk_size", None) + + +def _build_transformer(uv_wavelengths, real_space_mask): + """Inject per-instrument chunk_size into TransformerNUFFT without needing a + transformer_kwargs API on Interferometer.from_fits. Required for alma_high + (5M visibilities) to cap the nufftax gather buffer (PyAutoArray#330).""" + return al.TransformerNUFFT( + uv_wavelengths=uv_wavelengths, + real_space_mask=real_space_mask, + chunk_size=transformer_chunk_size, + ) + + with timer.section("dataset_list_load"): - # apply_sparse_operator: precompute the visibility-space sparse precision - # operator so per-fit curvature assembly uses the FFT-based sparse path - # instead of a dense DFT for every source pixel. 
Unblocked by - # PyAutoArray#316 (the Pmax > 1 extent-indexing fix); on Delaunay this was - # previously guarded with NotImplementedError. dataset_list = [ al.Interferometer.from_fits( data_path=dataset_path / "data.fits", noise_map_path=dataset_path / "noise_map.fits", uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, - transformer_class=al.TransformerDFT, - # DFT is mandatory here: apply_sparse_operator is not yet - # compatible with the new nufftax-backed al.TransformerNUFFT (see - # PyAutoArray/autoarray/dataset/interferometer/dataset.py:261). - # Swapping the transformer would raise NotImplementedError. - raise_error_dft_visibilities_limit=False, + transformer_class=_build_transformer, ).apply_sparse_operator(use_jax=True, show_progress=False) for _ in range(n_channels) ] @@ -427,333 +448,643 @@ def ray_trace_mesh_raw(mesh_raw): ) # --------------------------------------------------------------------------- -# Extract inversion matrices from channel 0 +# Always use the sparse-operator breakdown path # --------------------------------------------------------------------------- +# The dense path (extracting the full transformed_mapping_matrix and profiling +# its per-source-pixel construction) is not what the production likelihood +# does. Production always uses the sparse W~ operator via apply_sparse_operator. +# To get production-relevant per-step numbers across all instrument scales, +# we use the sparse breakdown path uniformly. 
-print("\n--- Extracting inversion matrices from channel 0 ---") - -inversion = fit.inversion - -with timer.section("extract_inversion_matrices"): - transformed_mm_ref = jnp.asarray(inversion.operated_mapping_matrix) - mapping_matrix_ref = jnp.asarray(inversion.mapping_matrix) - - inv_mapper = inversion.cls_list_from(cls=al.Mapper)[0] - neighbors = inv_mapper.neighbors - neighbors_array = jnp.array(np.asarray(neighbors)) - neighbors_sizes = jnp.array(neighbors.sizes) - -print(f" transformed_mapping_matrix shape: {transformed_mm_ref.shape}") -print(f" transformed_mapping_matrix dtype: {transformed_mm_ref.dtype}") -print(f" mapping_matrix shape: {mapping_matrix_ref.shape}") +dense_matrix_bytes = n_visibilities * n_mesh_vertices * 16 # complex128 (informational) +dense_breakdown_feasible = False # always sparse — matches production likelihood path -# --------------------------------------------------------------------------- -# Step 3: Inversion setup (per channel — NUFFT depends on uv_wavelengths) -# --------------------------------------------------------------------------- -# Steps 5-8 from the interferometer-sibling numbering (border + Delaunay + -# mapper + mapping matrix + NUFFT), combined and JIT-profiled from a pytree -# ModelInstance. Channel-variant because each channel's NUFFT uses its own -# uv_wavelengths. JIT-compile on channel 0; report cube cost as N × per-call. +if dense_breakdown_feasible: + print( + f"\n Dense transformed_mapping_matrix: " + f"{dense_matrix_bytes / 1e6:.1f} MB — full per-step breakdown enabled." + ) +else: + print( + f"\n Dense transformed_mapping_matrix would be " + f"{dense_matrix_bytes / 1e9:.1f} GB — using sparse-operator aggregate " + f"profiling instead (Steps 4-8 replaced by single pipeline step)." + ) -print("\n--- Step 3: Inversion setup, incl. NUFFT (per channel) ---") +inversion = fit.inversion -def transformed_mm_from_params(params_tree): - """Inversion setup from a pytree ModelInstance — full chain through NUFFT. 
+if dense_breakdown_feasible: + # ================================================================== + # DENSE PATH — full per-step breakdown (SMA scale) + # ================================================================== + + # --------------------------------------------------------------- + # Extract inversion matrices from channel 0 + # --------------------------------------------------------------- + + print("\n--- Extracting inversion matrices from channel 0 ---") + + with timer.section("extract_inversion_matrices"): + transformed_mm_ref = jnp.asarray(inversion.operated_mapping_matrix) + mapping_matrix_ref = jnp.asarray(inversion.mapping_matrix) + + inv_mapper = inversion.cls_list_from(cls=al.Mapper)[0] + neighbors = inv_mapper.neighbors + neighbors_array = jnp.array(np.asarray(neighbors)) + neighbors_sizes = jnp.array(neighbors.sizes) + + print(f" transformed_mapping_matrix shape: {transformed_mm_ref.shape}") + print(f" transformed_mapping_matrix dtype: {transformed_mm_ref.dtype}") + print(f" mapping_matrix shape: {mapping_matrix_ref.shape}") + + # --------------------------------------------------------------- + # Step 3: Inversion setup (per channel — NUFFT depends on uv_wavelengths) + # --------------------------------------------------------------- + + print("\n--- Step 3: Inversion setup, incl. 
NUFFT (per channel) ---") + + def transformed_mm_from_params(params_tree): + """Inversion setup from a pytree ModelInstance — full chain through NUFFT.""" + t = al.Tracer(galaxies=list(params_tree.galaxies)) + adapt_images_jax = al.AdaptImages( + galaxy_image_plane_mesh_grid_dict={ + params_tree.galaxies.source: image_plane_mesh_grid, + }, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'source')": image_plane_mesh_grid, + }, + ) + fit_jax = al.FitInterferometer( + dataset=dataset, + tracer=t, + adapt_images=adapt_images_jax, + xp=jnp, + ) + return jnp.asarray(fit_jax.inversion.operated_mapping_matrix) - This closes over ``dataset`` (channel 0) for the JIT compilation. In real - cube usage each channel's `AnalysisFactor` closes over its own - `dataset`, so the steady-state per-call cost is what we want to scale by N. - """ - t = al.Tracer(galaxies=list(params_tree.galaxies)) - adapt_images_jax = al.AdaptImages( - galaxy_image_plane_mesh_grid_dict={ - params_tree.galaxies.source: image_plane_mesh_grid, - }, - galaxy_name_image_plane_mesh_grid_dict={ - "('galaxies', 'source')": image_plane_mesh_grid, - }, + _, transformed_mm_jit = jit_profile( + transformed_mm_from_params, "inversion_setup_jit", params_tree ) - fit_jax = al.FitInterferometer( - dataset=dataset, - tracer=t, - adapt_images=adapt_images_jax, - xp=jnp, + inversion_setup_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Inversion setup, incl. 
NUFFT (per channel × {n_channels})", + n_channels * inversion_setup_per_channel, + ) ) - return jnp.asarray(fit_jax.inversion.operated_mapping_matrix) + print(f" per-channel: {inversion_setup_per_channel:.6f} s") + print(f" cube cost (× {n_channels}): {n_channels * inversion_setup_per_channel:.6f} s") + + # Use the reference real / imag arrays for the linear-algebra steps + transformed_mm_real_jnp = jnp.real(transformed_mm_ref) + transformed_mm_imag_jnp = jnp.imag(transformed_mm_ref) + data_real_jnp = jnp.array(dataset.data.real) + data_imag_jnp = jnp.array(dataset.data.imag) + noise_real_jnp = jnp.array(dataset.noise_map.real) + noise_imag_jnp = jnp.array(dataset.noise_map.imag) + + # --------------------------------------------------------------- + # Step 4: Data vector D (per channel) + # --------------------------------------------------------------- + + print("\n--- Step 4: Data vector D (per channel) ---") + + def compute_data_vector( + transformed_mm_real, transformed_mm_imag, data_real, data_imag, + noise_real, noise_imag, + ): + weighted_data_real = data_real / (noise_real ** 2) + weighted_data_imag = data_imag / (noise_imag ** 2) + return jnp.matmul(transformed_mm_real.T, weighted_data_real) + jnp.matmul( + transformed_mm_imag.T, weighted_data_imag + ) -_, transformed_mm_jit = jit_profile( - transformed_mm_from_params, "inversion_setup_jit", params_tree -) -inversion_setup_per_channel = timer.records[-1][1] / 10 -likelihood_steps.append( - ( - f"Inversion setup, incl. 
NUFFT (per channel × {n_channels})", - n_channels * inversion_setup_per_channel, + with timer.section("data_vector_eager"): + data_vector = compute_data_vector( + transformed_mm_real_jnp, transformed_mm_imag_jnp, + data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, + ) + block(data_vector) + + _, data_vector = jit_profile( + compute_data_vector, "data_vector_jit", + transformed_mm_real_jnp, transformed_mm_imag_jnp, + data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, + ) + data_vector_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Data vector D (per channel × {n_channels})", + n_channels * data_vector_per_channel, + ) ) -) -print(f" per-channel: {inversion_setup_per_channel:.6f} s") -print(f" cube cost (× {n_channels}): {n_channels * inversion_setup_per_channel:.6f} s") + # --------------------------------------------------------------- + # Step 5: Curvature matrix F (per channel) + # --------------------------------------------------------------- -# Use the reference real / imag arrays for the linear-algebra steps -transformed_mm_real_jnp = jnp.real(transformed_mm_ref) -transformed_mm_imag_jnp = jnp.imag(transformed_mm_ref) -data_real_jnp = jnp.array(dataset.data.real) -data_imag_jnp = jnp.array(dataset.data.imag) -noise_real_jnp = jnp.array(dataset.noise_map.real) -noise_imag_jnp = jnp.array(dataset.noise_map.imag) + print("\n--- Step 5: Curvature matrix F (per channel) ---") -# --------------------------------------------------------------------------- -# Step 4: Data vector D (per channel) -# --------------------------------------------------------------------------- + no_reg_list = list(inversion.no_regularization_index_list) -print("\n--- Step 4: Data vector D (per channel) ---") + def compute_curvature_matrix( + transformed_mm_real, transformed_mm_imag, noise_real, noise_imag, + ): + real_curv = al.util.inversion.curvature_matrix_via_mapping_matrix_from( + mapping_matrix=transformed_mm_real, + 
noise_map=noise_real, + settings=fit.settings, + add_to_curvature_diag=True, + no_regularization_index_list=no_reg_list, + xp=jnp, + ) + imag_curv = al.util.inversion.curvature_matrix_via_mapping_matrix_from( + mapping_matrix=transformed_mm_imag, + noise_map=noise_imag, + settings=fit.settings, + add_to_curvature_diag=False, + no_regularization_index_list=no_reg_list, + xp=jnp, + ) + return real_curv + imag_curv + with timer.section("curvature_matrix_eager"): + curvature_matrix = compute_curvature_matrix( + transformed_mm_real_jnp, transformed_mm_imag_jnp, noise_real_jnp, noise_imag_jnp, + ) + block(curvature_matrix) -def compute_data_vector( - transformed_mm_real, transformed_mm_imag, data_real, data_imag, - noise_real, noise_imag, -): - weighted_data_real = data_real / (noise_real ** 2) - weighted_data_imag = data_imag / (noise_imag ** 2) - return jnp.matmul(transformed_mm_real.T, weighted_data_real) + jnp.matmul( - transformed_mm_imag.T, weighted_data_imag + _, curvature_matrix = jit_profile( + compute_curvature_matrix, "curvature_matrix_jit", + transformed_mm_real_jnp, transformed_mm_imag_jnp, noise_real_jnp, noise_imag_jnp, + ) + curvature_matrix_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Curvature matrix F (per channel × {n_channels})", + n_channels * curvature_matrix_per_channel, + ) ) + # --------------------------------------------------------------- + # Step 6: Regularization matrix H (channel-invariant) + # --------------------------------------------------------------- -with timer.section("data_vector_eager"): - data_vector = compute_data_vector( - transformed_mm_real_jnp, transformed_mm_imag_jnp, - data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, - ) - block(data_vector) + print("\n--- Step 6: Regularization matrix H (shared) ---") -_, data_vector = jit_profile( - compute_data_vector, "data_vector_jit", - transformed_mm_real_jnp, transformed_mm_imag_jnp, - data_real_jnp, data_imag_jnp, noise_real_jnp, 
noise_imag_jnp, -) -data_vector_per_channel = timer.records[-1][1] / 10 -likelihood_steps.append( - ( - f"Data vector D (per channel × {n_channels})", - n_channels * data_vector_per_channel, + with timer.section("regularization_matrix_eager"): + regularization_matrix = jnp.array(inversion.regularization_matrix) + block(regularization_matrix) + + likelihood_steps.append( + ("Regularization matrix H (shared)", timer.records[-1][1]) ) -) -# --------------------------------------------------------------------------- -# Step 5: Curvature matrix F (per channel) -# --------------------------------------------------------------------------- + # --------------------------------------------------------------- + # Step 7: Reconstruction NNLS (per channel) + # --------------------------------------------------------------- -print("\n--- Step 5: Curvature matrix F (per channel) ---") + print("\n--- Step 7: Regularized reconstruction (per channel) ---") -no_reg_list = list(inversion.no_regularization_index_list) + def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix): + curvature_reg_matrix = curvature_matrix + regularization_matrix + return al.util.inversion.reconstruction_positive_only_from( + data_vector=data_vector, + curvature_reg_matrix=curvature_reg_matrix, + xp=jnp, + ) + with timer.section("reconstruction_eager"): + reconstruction = compute_reconstruction( + jnp.array(data_vector), + jnp.array(curvature_matrix), + jnp.array(regularization_matrix), + ) + block(reconstruction) -def compute_curvature_matrix( - transformed_mm_real, transformed_mm_imag, noise_real, noise_imag, -): - real_curv = al.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=transformed_mm_real, - noise_map=noise_real, - settings=fit.settings, - add_to_curvature_diag=True, - no_regularization_index_list=no_reg_list, - xp=jnp, + _, reconstruction = jit_profile( + compute_reconstruction, "reconstruction_jit", + jnp.array(data_vector), 
jnp.array(curvature_matrix), jnp.array(regularization_matrix), ) - imag_curv = al.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=transformed_mm_imag, - noise_map=noise_imag, - settings=fit.settings, - add_to_curvature_diag=False, - no_regularization_index_list=no_reg_list, - xp=jnp, + reconstruction_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Reconstruction NNLS (per channel × {n_channels})", + n_channels * reconstruction_per_channel, + ) ) - return real_curv + imag_curv + # --------------------------------------------------------------- + # Step 8: Mapped recon + log evidence (per channel) + # --------------------------------------------------------------- -with timer.section("curvature_matrix_eager"): - curvature_matrix = compute_curvature_matrix( - transformed_mm_real_jnp, transformed_mm_imag_jnp, noise_real_jnp, noise_imag_jnp, - ) - block(curvature_matrix) + print("\n--- Step 8: Mapped recon + log evidence (per channel) ---") -_, curvature_matrix = jit_profile( - compute_curvature_matrix, "curvature_matrix_jit", - transformed_mm_real_jnp, transformed_mm_imag_jnp, noise_real_jnp, noise_imag_jnp, -) -curvature_matrix_per_channel = timer.records[-1][1] / 10 -likelihood_steps.append( - ( - f"Curvature matrix F (per channel × {n_channels})", - n_channels * curvature_matrix_per_channel, - ) -) + def compute_log_evidence( + data_real, data_imag, noise_real, noise_imag, + transformed_mm_real, transformed_mm_imag, + reconstruction, curvature_matrix, regularization_matrix, mapper_indices, + ): + mapped_real = jnp.matmul(transformed_mm_real, reconstruction) + mapped_imag = jnp.matmul(transformed_mm_imag, reconstruction) -# --------------------------------------------------------------------------- -# Step 6: Regularization matrix H (channel-invariant) -# --------------------------------------------------------------------------- + chi_real = jnp.sum(((data_real - mapped_real) / noise_real) ** 2) + chi_imag = 
jnp.sum(((data_imag - mapped_imag) / noise_imag) ** 2) + chi_squared = chi_real + chi_imag -print("\n--- Step 6: Regularization matrix H (shared) ---") + regularization_term = jnp.dot( + reconstruction, jnp.dot(regularization_matrix, reconstruction) + ) -with timer.section("regularization_matrix_eager"): - regularization_matrix = jnp.array(inversion.regularization_matrix) - block(regularization_matrix) + curvature_reg_matrix = curvature_matrix + regularization_matrix + creg_reduced = curvature_reg_matrix[mapper_indices][:, mapper_indices] + reg_reduced = regularization_matrix[mapper_indices][:, mapper_indices] + log_det_curvature_reg = 2.0 * jnp.sum( + jnp.log(jnp.diag(jnp.linalg.cholesky(creg_reduced))) + ) + log_det_regularization = 2.0 * jnp.sum( + jnp.log(jnp.diag(jnp.linalg.cholesky(reg_reduced))) + ) -likelihood_steps.append( - ("Regularization matrix H (shared)", timer.records[-1][1]) -) + noise_normalization = ( + jnp.sum(jnp.log(2 * jnp.pi * noise_real ** 2)) + + jnp.sum(jnp.log(2 * jnp.pi * noise_imag ** 2)) + ) -# --------------------------------------------------------------------------- -# Step 7: Reconstruction NNLS (per channel) -# --------------------------------------------------------------------------- + return -0.5 * ( + chi_squared + regularization_term + log_det_curvature_reg + - log_det_regularization + noise_normalization + ) -print("\n--- Step 7: Regularized reconstruction (per channel) ---") + mapper_indices_jnp = jnp.array(np.asarray(inversion.mapper_indices)) + inv_recon_jnp = jnp.asarray(inversion.reconstruction) + inv_curv_jnp = jnp.asarray(inversion.curvature_matrix) + reg_jnp = jnp.array(regularization_matrix) + with timer.section("log_evidence_eager"): + log_evidence_one_channel = compute_log_evidence( + data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, + transformed_mm_real_jnp, transformed_mm_imag_jnp, + reconstruction, curvature_matrix, reg_jnp, mapper_indices_jnp, + ) + block(log_evidence_one_channel) -def 
compute_reconstruction(data_vector, curvature_matrix, regularization_matrix): - curvature_reg_matrix = curvature_matrix + regularization_matrix - return al.util.inversion.reconstruction_positive_only_from( - data_vector=data_vector, - curvature_reg_matrix=curvature_reg_matrix, - xp=jnp, + _, log_evidence_one_channel = jit_profile( + compute_log_evidence, "log_evidence_jit", + data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, + transformed_mm_real_jnp, transformed_mm_imag_jnp, + reconstruction, curvature_matrix, reg_jnp, mapper_indices_jnp, + ) + log_evidence_per_channel_cost = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Mapped recon + log evidence (per channel × {n_channels})", + n_channels * log_evidence_per_channel_cost, + ) ) + print(f" channel 0 log_evidence (step-by-step) = {log_evidence_one_channel}") + + # Correctness check: recompute per-channel log_evidence using each + # channel's inversion matrices. + log_evidence_check_per_channel = [] + for c, f in enumerate(fit_list): + inv_c = f.inversion + tm_real = jnp.real(jnp.asarray(inv_c.operated_mapping_matrix)) + tm_imag = jnp.imag(jnp.asarray(inv_c.operated_mapping_matrix)) + le_c = compute_log_evidence( + jnp.array(dataset_list[c].data.real), + jnp.array(dataset_list[c].data.imag), + jnp.array(dataset_list[c].noise_map.real), + jnp.array(dataset_list[c].noise_map.imag), + tm_real, tm_imag, + jnp.asarray(inv_c.reconstruction), + jnp.asarray(inv_c.curvature_matrix), + jnp.array(inv_c.regularization_matrix), + jnp.array(np.asarray(inv_c.mapper_indices)), + ) + log_evidence_check_per_channel.append(float(le_c)) -with timer.section("reconstruction_eager"): - reconstruction = compute_reconstruction( - jnp.array(data_vector), - jnp.array(curvature_matrix), - jnp.array(regularization_matrix), + cube_log_evidence_check = float(sum(log_evidence_check_per_channel)) + print( + f"\n cube log_evidence (per-step recompute) = {cube_log_evidence_check:.6f}" ) - block(reconstruction) + print(f" 
cube log_evidence (reference) = {cube_log_evidence_ref:.6f}") -_, reconstruction = jit_profile( - compute_reconstruction, "reconstruction_jit", - jnp.array(data_vector), jnp.array(curvature_matrix), jnp.array(regularization_matrix), -) -reconstruction_per_channel = timer.records[-1][1] / 10 -likelihood_steps.append( - ( - f"Reconstruction NNLS (per channel × {n_channels})", - n_channels * reconstruction_per_channel, + np.testing.assert_allclose( + cube_log_evidence_check, + cube_log_evidence_ref, + rtol=1e-4, + err_msg=( + "Per-step cube log_evidence does not match summed " + "FitInterferometer.log_evidence" + ), + ) + print( + " Assertion PASSED: per-step cube log_evidence matches summed " + "FitInterferometer.log_evidence at rtol=1e-4" ) -) -# --------------------------------------------------------------------------- -# Step 8: Mapped recon + log evidence (per channel) -# --------------------------------------------------------------------------- +else: + # ================================================================== + # SPARSE PATH — per-step JIT profiling using FFT-based W~ operator + # ================================================================== + # The dense transformed_mapping_matrix is too large to materialise at + # ALMA scale (24+ GB), but the intermediate quantities used by the + # sparse path — mapping_matrix (~46 MB), dirty_image (~30 KB), + # sparse_operator W~ kernel (~41 MB) — all fit. Extract them and + # profile each sparse step individually, mirroring the dense breakdown. 
+ + from autoarray.inversion.mappers.mapper_util import sparse_triplets_from + from autoarray.inversion.mappers.abstract import Mapper as _MapperCls + + # --------------------------------------------------------------- + # Extract sparse references from channel 0's inversion + # --------------------------------------------------------------- + + print("\n--- Extracting sparse inversion references from channel 0 ---") + + with timer.section("extract_sparse_references"): + # Small dense quantities — all fit comfortably. + mapping_matrix_jnp = jnp.asarray(inversion.mapping_matrix) + dirty_image_jnp = jnp.asarray(dataset.sparse_operator.dirty_image) + regularization_matrix = jnp.array(inversion.regularization_matrix) + sparse_operator = dataset.sparse_operator + + inv_mapper = inversion.cls_list_from(cls=_MapperCls)[0] + rows_jnp, cols_jnp, vals_jnp = sparse_triplets_from( + pix_indexes_for_sub=inv_mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=inv_mapper.pix_weights_for_sub_slim_index, + slim_index_for_sub=inv_mapper.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=inversion.mask.extent_index_for_masked_pixel, + sub_fraction_slim=inv_mapper.over_sampler.sub_fraction.array, + return_rows_slim=False, + xp=jnp, + ) + S = int(inv_mapper.params) + mapper_indices_jnp = jnp.array(np.asarray(inversion.mapper_indices)) + data_real_jnp = jnp.array(dataset.data.real) + data_imag_jnp = jnp.array(dataset.data.imag) + noise_real_jnp = jnp.array(dataset.noise_map.real) + noise_imag_jnp = jnp.array(dataset.noise_map.imag) + + print(f" mapping_matrix shape: {mapping_matrix_jnp.shape}") + print(f" dirty_image shape: {dirty_image_jnp.shape}") + print(f" sparse triplets nnz: {rows_jnp.shape[0]}") + print(f" W~ Khat shape: {sparse_operator.Khat.shape}") + print(f" source pixels (S): {S}") + + # --------------------------------------------------------------- + # Step 3: Inversion setup, incl. 
NUFFT (per channel) + # --------------------------------------------------------------- + # At ALMA scale we cannot return the dense operated_mapping_matrix. + # Return the sparse data_vector instead (shape (S,), ~12 KB), which + # still exercises the full chain: tracer → mapper → mapping_matrix → + # sparse_operator.dirty_image → D = mapping_matrix.T @ dirty_image. + + print("\n--- Step 3: Inversion setup, incl. NUFFT (per channel) ---") + + def inversion_setup_from_params(params_tree): + """Full inversion setup via FitInterferometer; returns sparse data_vector.""" + t = al.Tracer(galaxies=list(params_tree.galaxies)) + adapt_images_jax = al.AdaptImages( + galaxy_image_plane_mesh_grid_dict={ + params_tree.galaxies.source: image_plane_mesh_grid, + }, + galaxy_name_image_plane_mesh_grid_dict={ + "('galaxies', 'source')": image_plane_mesh_grid, + }, + ) + fit_jax = al.FitInterferometer( + dataset=dataset, + tracer=t, + adapt_images=adapt_images_jax, + settings=al.Settings(use_mixed_precision=_cli.use_mixed_precision), + xp=jnp, + ) + return jnp.asarray(fit_jax.inversion.data_vector) + + _, _ = jit_profile( + inversion_setup_from_params, "inversion_setup_jit", params_tree + ) + inversion_setup_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Inversion setup, incl. 
NUFFT (per channel × {n_channels})", + n_channels * inversion_setup_per_channel, + ) + ) + print(f" per-channel: {inversion_setup_per_channel:.6f} s") -print("\n--- Step 8: Mapped recon + log evidence (per channel) ---") + # --------------------------------------------------------------- + # Step 4: Sparse data vector D = mapping_matrix^T @ dirty_image + # --------------------------------------------------------------- + print("\n--- Step 4: Sparse data vector D (per channel) ---") -def compute_log_evidence( - data_real, data_imag, noise_real, noise_imag, - transformed_mm_real, transformed_mm_imag, - reconstruction, curvature_matrix, regularization_matrix, mapper_indices, -): - mapped_real = jnp.matmul(transformed_mm_real, reconstruction) - mapped_imag = jnp.matmul(transformed_mm_imag, reconstruction) + def compute_data_vector_sparse(mapping_matrix, dirty_image): + return jnp.dot(mapping_matrix.T, dirty_image) - chi_real = jnp.sum(((data_real - mapped_real) / noise_real) ** 2) - chi_imag = jnp.sum(((data_imag - mapped_imag) / noise_imag) ** 2) - chi_squared = chi_real + chi_imag + with timer.section("data_vector_sparse_eager"): + data_vector = compute_data_vector_sparse(mapping_matrix_jnp, dirty_image_jnp) + block(data_vector) - regularization_term = jnp.dot( - reconstruction, jnp.dot(regularization_matrix, reconstruction) + _, data_vector = jit_profile( + compute_data_vector_sparse, "data_vector_sparse_jit", + mapping_matrix_jnp, dirty_image_jnp, + ) + data_vector_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Data vector D, sparse (per channel × {n_channels})", + n_channels * data_vector_per_channel, + ) ) - curvature_reg_matrix = curvature_matrix + regularization_matrix - creg_reduced = curvature_reg_matrix[mapper_indices][:, mapper_indices] - reg_reduced = regularization_matrix[mapper_indices][:, mapper_indices] - log_det_curvature_reg = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(creg_reduced))) + # 
--------------------------------------------------------------- + # Step 5: Sparse curvature matrix F = A^T W~ A via FFT + # --------------------------------------------------------------- + + print("\n--- Step 5: Sparse curvature matrix F via FFT W~ (per channel) ---") + + def compute_curvature_matrix_sparse(rows, cols, vals): + return sparse_operator.curvature_matrix_diag_from( + rows=rows, cols=cols, vals=vals, S=S + ) + + with timer.section("curvature_matrix_sparse_eager"): + curvature_matrix = compute_curvature_matrix_sparse(rows_jnp, cols_jnp, vals_jnp) + block(curvature_matrix) + + _, curvature_matrix = jit_profile( + compute_curvature_matrix_sparse, "curvature_matrix_sparse_jit", + rows_jnp, cols_jnp, vals_jnp, ) - log_det_regularization = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(reg_reduced))) + curvature_matrix_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Curvature matrix F, sparse (per channel × {n_channels})", + n_channels * curvature_matrix_per_channel, + ) ) - noise_normalization = ( - jnp.sum(jnp.log(2 * jnp.pi * noise_real ** 2)) - + jnp.sum(jnp.log(2 * jnp.pi * noise_imag ** 2)) + # --------------------------------------------------------------- + # Step 6: Regularization matrix H (channel-invariant) + # --------------------------------------------------------------- + + print("\n--- Step 6: Regularization matrix H (shared) ---") + + likelihood_steps.append( + ("Regularization matrix H (shared)", 0.0) ) - return -0.5 * ( - chi_squared + regularization_term + log_det_curvature_reg - - log_det_regularization + noise_normalization + # --------------------------------------------------------------- + # Step 7: Reconstruction NNLS (per channel) — same as dense path + # --------------------------------------------------------------- + + print("\n--- Step 7: Regularized reconstruction (per channel) ---") + + def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix): + curvature_reg_matrix = 
curvature_matrix + regularization_matrix + return al.util.inversion.reconstruction_positive_only_from( + data_vector=data_vector, + curvature_reg_matrix=curvature_reg_matrix, + xp=jnp, + ) + + with timer.section("reconstruction_eager"): + reconstruction = compute_reconstruction( + data_vector, curvature_matrix, regularization_matrix + ) + block(reconstruction) + + _, reconstruction = jit_profile( + compute_reconstruction, "reconstruction_jit", + data_vector, curvature_matrix, regularization_matrix, ) + reconstruction_per_channel = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"Reconstruction NNLS (per channel × {n_channels})", + n_channels * reconstruction_per_channel, + ) + ) + + # --------------------------------------------------------------- + # Step 8: fast_chi_squared + log-evidence (per channel) + # --------------------------------------------------------------- + + print("\n--- Step 8: fast chi-squared + log evidence (per channel) ---") + + def compute_log_evidence_sparse( + data_real, data_imag, noise_real, noise_imag, + reconstruction, data_vector, curvature_matrix, + regularization_matrix, mapper_indices, + ): + # fast_chi_squared (avoids forming residual visibilities) + chi_sq_t1 = jnp.dot(reconstruction, jnp.dot(curvature_matrix, reconstruction)) + chi_sq_t2 = -2.0 * jnp.dot(reconstruction, data_vector) + chi_sq_t3 = ( + jnp.sum(data_real ** 2 / noise_real ** 2) + + jnp.sum(data_imag ** 2 / noise_imag ** 2) + ) + chi_squared = chi_sq_t1 + chi_sq_t2 + chi_sq_t3 + + regularization_term = jnp.dot( + reconstruction, jnp.dot(regularization_matrix, reconstruction) + ) + curvature_reg_matrix = curvature_matrix + regularization_matrix + creg_reduced = curvature_reg_matrix[mapper_indices][:, mapper_indices] + reg_reduced = regularization_matrix[mapper_indices][:, mapper_indices] + log_det_curvature_reg = 2.0 * jnp.sum( + jnp.log(jnp.diag(jnp.linalg.cholesky(creg_reduced))) + ) + log_det_regularization = 2.0 * jnp.sum( + 
jnp.log(jnp.diag(jnp.linalg.cholesky(reg_reduced))) + ) -mapper_indices_jnp = jnp.array(np.asarray(inversion.mapper_indices)) -inv_recon_jnp = jnp.asarray(inversion.reconstruction) -inv_curv_jnp = jnp.asarray(inversion.curvature_matrix) -reg_jnp = jnp.array(regularization_matrix) + noise_normalization = ( + jnp.sum(jnp.log(2 * jnp.pi * noise_real ** 2)) + + jnp.sum(jnp.log(2 * jnp.pi * noise_imag ** 2)) + ) -with timer.section("log_evidence_eager"): - log_evidence_one_channel = compute_log_evidence( + return -0.5 * ( + chi_squared + regularization_term + log_det_curvature_reg + - log_det_regularization + noise_normalization + ) + + with timer.section("log_evidence_sparse_eager"): + log_evidence_one_channel = compute_log_evidence_sparse( + data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, + reconstruction, data_vector, curvature_matrix, + regularization_matrix, mapper_indices_jnp, + ) + block(log_evidence_one_channel) + + _, log_evidence_one_channel = jit_profile( + compute_log_evidence_sparse, "log_evidence_sparse_jit", data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, - transformed_mm_real_jnp, transformed_mm_imag_jnp, - reconstruction, curvature_matrix, reg_jnp, mapper_indices_jnp, + reconstruction, data_vector, curvature_matrix, + regularization_matrix, mapper_indices_jnp, + ) + log_evidence_per_channel_cost = timer.records[-1][1] / 10 + likelihood_steps.append( + ( + f"fast_chi_squared + log evidence (per channel × {n_channels})", + n_channels * log_evidence_per_channel_cost, + ) ) - block(log_evidence_one_channel) -_, log_evidence_one_channel = jit_profile( - compute_log_evidence, "log_evidence_jit", - data_real_jnp, data_imag_jnp, noise_real_jnp, noise_imag_jnp, - transformed_mm_real_jnp, transformed_mm_imag_jnp, - reconstruction, curvature_matrix, reg_jnp, mapper_indices_jnp, -) -log_evidence_per_channel_cost = timer.records[-1][1] / 10 -likelihood_steps.append( - ( - f"Mapped recon + log evidence (per channel × {n_channels})", - 
n_channels * log_evidence_per_channel_cost, + print(f" channel 0 log_evidence (step-by-step) = {float(log_evidence_one_channel):.6f}") + + # --------------------------------------------------------------- + # Correctness check (sparse path) + # --------------------------------------------------------------- + # Recompute per-channel log_evidence from each channel's sparse + # inversion matrices; sum and compare to FitInterferometer.log_evidence. + + log_evidence_check_per_channel = [] + for c, f in enumerate(fit_list): + inv_c = f.inversion + mm_c = jnp.asarray(inv_c.mapping_matrix) + di_c = jnp.asarray(dataset_list[c].sparse_operator.dirty_image) + dv_c = jnp.dot(mm_c.T, di_c) + cm_c = jnp.asarray(inv_c.curvature_matrix) + rm_c = jnp.array(inv_c.regularization_matrix) + rc_c = jnp.asarray(inv_c.reconstruction) + mi_c = jnp.array(np.asarray(inv_c.mapper_indices)) + + le_c = compute_log_evidence_sparse( + jnp.array(dataset_list[c].data.real), + jnp.array(dataset_list[c].data.imag), + jnp.array(dataset_list[c].noise_map.real), + jnp.array(dataset_list[c].noise_map.imag), + rc_c, dv_c, cm_c, rm_c, mi_c, + ) + log_evidence_check_per_channel.append(float(le_c)) + + cube_log_evidence_check = float(sum(log_evidence_check_per_channel)) + print( + f"\n cube log_evidence (per-step recompute) = {cube_log_evidence_check:.6f}" ) -) + print(f" cube log_evidence (reference) = {cube_log_evidence_ref:.6f}") -print(f" channel 0 log_evidence (step-by-step) = {log_evidence_one_channel}") - -# Correctness check: recompute per-channel log_evidence using each channel's -# inversion matrices and sum to get the cube log-evidence. Should match the -# summed eager FitInterferometer.log_evidence at rtol=1e-4. 
-log_evidence_check_per_channel = [] -for c, f in enumerate(fit_list): - inv_c = f.inversion - tm_real = jnp.real(jnp.asarray(inv_c.operated_mapping_matrix)) - tm_imag = jnp.imag(jnp.asarray(inv_c.operated_mapping_matrix)) - le_c = compute_log_evidence( - jnp.array(dataset_list[c].data.real), - jnp.array(dataset_list[c].data.imag), - jnp.array(dataset_list[c].noise_map.real), - jnp.array(dataset_list[c].noise_map.imag), - tm_real, tm_imag, - jnp.asarray(inv_c.reconstruction), - jnp.asarray(inv_c.curvature_matrix), - jnp.array(inv_c.regularization_matrix), - jnp.array(np.asarray(inv_c.mapper_indices)), - ) - log_evidence_check_per_channel.append(float(le_c)) - -cube_log_evidence_check = float(sum(log_evidence_check_per_channel)) -print( - f"\n cube log_evidence (per-step recompute) = {cube_log_evidence_check:.6f}" -) -print(f" cube log_evidence (reference) = {cube_log_evidence_ref:.6f}") - -np.testing.assert_allclose( - cube_log_evidence_check, - cube_log_evidence_ref, - rtol=1e-4, - err_msg=( - "Per-step cube log_evidence does not match summed FitInterferometer.log_evidence" - ), -) -print( - " Assertion PASSED: per-step cube log_evidence matches summed " - "FitInterferometer.log_evidence at rtol=1e-4" -) + np.testing.assert_allclose( + cube_log_evidence_check, + cube_log_evidence_ref, + rtol=1e-4, + err_msg=( + "Per-step sparse cube log_evidence does not match summed " + "FitInterferometer.log_evidence" + ), + ) + print( + " Assertion PASSED: per-step sparse cube log_evidence matches summed " + "FitInterferometer.log_evidence at rtol=1e-4" + ) # =================================================================== @@ -791,11 +1122,18 @@ def compute_log_evidence( # Shared-Lᵀ W̃ L optimisation savings estimate: # Moving the curvature matrix from per-channel to shared would save # (n_channels - 1) × per-channel curvature matrix cost. -shared_lwl_savings = (n_channels - 1) * curvature_matrix_per_channel +# In sparse mode the isolated curvature cost is not available. 
+if curvature_matrix_per_channel is not None: + shared_lwl_savings = (n_channels - 1) * curvature_matrix_per_channel +else: + shared_lwl_savings = None print("-" * 70) print(f" {'TOTAL (step-by-step cube cost)':<{max_label}} {step_total:>12.6f} s") -print(f" {f'Shared-Lᵀ W̃ L savings (curvature only, est.)':<{max_label}} {shared_lwl_savings:>12.6f} s") +if shared_lwl_savings is not None: + print(f" {f'Shared-Lᵀ W̃ L savings (curvature only, est.)':<{max_label}} {shared_lwl_savings:>12.6f} s") +else: + print(f" Shared-Lᵀ W̃ L savings: N/A (run at SMA scale for per-step granularity)") print("=" * 70) # --- Save results dictionary --- @@ -806,6 +1144,7 @@ def compute_log_evidence( "instrument": instrument, "model": "delaunay", "n_channels": n_channels, + "dense_breakdown": dense_breakdown_feasible, "configuration": { "pixel_scale_arcsec": pixel_scale, "mask_radius_arcsec": mask_radius, @@ -815,6 +1154,7 @@ def compute_log_evidence( "delaunay_vertices": int(n_mesh_vertices), "edge_zeroed_pixels": int(edge_pixels_total), "regularization_coefficient": regularization_coefficient, + "transformer_chunk_size": transformer_chunk_size, }, "cube_log_evidence_eager": cube_log_evidence_ref, "log_evidence_per_channel_eager": [float(le) for le in log_evidence_per_channel], @@ -858,13 +1198,14 @@ def compute_log_evidence( fontsize=9, ) -ax.axvline( - shared_lwl_savings, - color="#8172B2", - linestyle=":", - linewidth=1.5, - label=f"Shared-Lᵀ W̃ L savings est.: {shared_lwl_savings:.6f} s", -) +if shared_lwl_savings is not None: + ax.axvline( + shared_lwl_savings, + color="#8172B2", + linestyle=":", + linewidth=1.5, + label=f"Shared-Lᵀ W̃ L savings est.: {shared_lwl_savings:.6f} s", + ) ax.set_yticks(y_pos) ax.set_yticklabels(labels, fontsize=10) @@ -882,7 +1223,8 @@ def compute_log_evidence( f"step-by-step total: {step_total:.6f} s", fontsize=9, ) -ax.legend(loc="lower right", fontsize=9) +if shared_lwl_savings is not None: + ax.legend(loc="lower right", fontsize=9) 
ax.margins(x=0.18) fig.tight_layout() diff --git a/likelihood_breakdown/interferometer/delaunay.py b/likelihood_breakdown/interferometer/delaunay.py index 75e971f..2f3309a 100644 --- a/likelihood_breakdown/interferometer/delaunay.py +++ b/likelihood_breakdown/interferometer/delaunay.py @@ -84,7 +84,7 @@ instrument = "sma" # <-- change this to profile a different instrument -hilbert_pixels = 1000 # 1000-tier production fiducial for Hilbert + Delaunay +hilbert_pixels = 1500 # 1500-tier production fiducial (matches imaging/datacube) regularization_coefficient = 1.0 @@ -172,24 +172,34 @@ def jit_profile(func, label, *args, n_repeats=10): radius=mask_radius, ) +transformer_chunk_size = INSTRUMENTS[instrument].get("transformer_chunk_size", None) + + +def _build_transformer(uv_wavelengths, real_space_mask): + """Inject per-instrument chunk_size into TransformerNUFFT without needing a + transformer_kwargs API on Interferometer.from_fits. Required for alma_high + (5M visibilities) to cap the nufftax gather buffer (PyAutoArray#330).""" + return al.TransformerNUFFT( + uv_wavelengths=uv_wavelengths, + real_space_mask=real_space_mask, + chunk_size=transformer_chunk_size, + ) + + with timer.section("dataset_load"): dataset = al.Interferometer.from_fits( data_path=dataset_path / "data.fits", noise_map_path=dataset_path / "noise_map.fits", uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, - transformer_class=al.TransformerDFT, - # DFT is intentional even at ALMA-scale visibility counts — profiling - # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet - # JIT-friendly. - raise_error_dft_visibilities_limit=False, + transformer_class=_build_transformer, ) with timer.section("apply_sparse_operator"): - # Precompute the NUFFT precision-matrix preload so per-fit curvature - # assembly uses the FFT-based sparse path instead of dense DFT for every - # source pixel. 
Unblocked by PyAutoArray#316 (the Pmax > 1 extent-indexing - # fix); on Delaunay this was previously guarded with NotImplementedError. + # Precompute the W~ precision-matrix preload + dirty image so per-fit + # curvature assembly uses the FFT-based sparse path instead of the dense + # transformed_mapping_matrix. The NUFFT keeps the one-time dirty-image + # setup tractable at ALMA-scale visibility counts (PyAutoArray#329). dataset = dataset.apply_sparse_operator(use_jax=True, show_progress=True) n_visibilities = dataset.uv_wavelengths.shape[0] diff --git a/likelihood_runtime/datacube/delaunay.py b/likelihood_runtime/datacube/delaunay.py index aa3f0b0..67d0d30 100644 --- a/likelihood_runtime/datacube/delaunay.py +++ b/likelihood_runtime/datacube/delaunay.py @@ -136,7 +136,7 @@ # n_channels = 34 matches the prior Hannah ALMA cube fiducial. For quick # iteration on the smaller sma dataset, drop this to 4. n_channels = 34 -hilbert_pixels = 500 # 500-tier production fiducial per channel (× n_channels) +hilbert_pixels = 1500 # 1500-tier production fiducial (matches imaging/interferometer) regularization_coefficient = 1.0 diff --git a/likelihood_runtime/interferometer/delaunay.py b/likelihood_runtime/interferometer/delaunay.py index 76d637c..c30579b 100644 --- a/likelihood_runtime/interferometer/delaunay.py +++ b/likelihood_runtime/interferometer/delaunay.py @@ -132,7 +132,7 @@ instrument = _cli.instrument or "sma" # default; override via --instrument -hilbert_pixels = 1000 # 1000-tier production fiducial for Hilbert + Delaunay +hilbert_pixels = 1500 # 1500-tier production fiducial (matches imaging/datacube) regularization_coefficient = 1.0 diff --git a/likelihood_runtime/interferometer/pixelization.py b/likelihood_runtime/interferometer/pixelization.py index 4ae7d59..a23c6d7 100644 --- a/likelihood_runtime/interferometer/pixelization.py +++ b/likelihood_runtime/interferometer/pixelization.py @@ -181,20 +181,34 @@ def jit_profile(func, label, *args, n_repeats=10): 
radius=mask_radius, ) +transformer_chunk_size = INSTRUMENTS[instrument].get("transformer_chunk_size", None) + + +def _build_transformer(uv_wavelengths, real_space_mask): + """Inject per-instrument chunk_size into TransformerNUFFT without needing a + transformer_kwargs API on Interferometer.from_fits. Required for alma_high + (5M visibilities) to cap the nufftax gather buffer (PyAutoArray#330).""" + return al.TransformerNUFFT( + uv_wavelengths=uv_wavelengths, + real_space_mask=real_space_mask, + chunk_size=transformer_chunk_size, + ) + + with timer.section("dataset_load"): dataset = al.Interferometer.from_fits( data_path=dataset_path / "data.fits", noise_map_path=dataset_path / "noise_map.fits", uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, - transformer_class=al.TransformerDFT, + transformer_class=_build_transformer, ) with timer.section("apply_sparse_operator"): - # Precompute the FFT-based sparse operator so per-fit curvature assembly - # uses the sparse path instead of dense DFT on every likelihood call. - # Compatible only with TransformerDFT / TransformerNUFFTPyNUFFT today - # (see PyAutoArray/autoarray/dataset/interferometer/dataset.py:261). + # Precompute the W~ precision-matrix preload + dirty image so per-fit + # curvature assembly uses the FFT-based sparse path instead of the dense + # transformed_mapping_matrix. The NUFFT keeps the one-time dirty-image + # setup tractable at ALMA-scale visibility counts (PyAutoArray#329). dataset = dataset.apply_sparse_operator(use_jax=True, show_progress=True) n_visibilities = dataset.uv_wavelengths.shape[0]