diff --git a/config/build/env_vars.yaml b/config/build/env_vars.yaml index 98b0529..b2faf63 100644 --- a/config/build/env_vars.yaml +++ b/config/build/env_vars.yaml @@ -54,14 +54,23 @@ overrides: # visualization.py asserts subplot PNG files exist on disk, but # PYAUTO_FAST_PLOTS=1 short-circuits subplot_save() in PyAutoArray so # no file is ever written. - - pattern: "interferometer/visualization" + - pattern: "interferometer/visualization.py" unset: [PYAUTO_FAST_PLOTS] + # interferometer/visualization_jax exercises the jit-cached + # fit_for_visualization path on the interferometer side. PYAUTO_DISABLE_JAX=1 + # would silently flip use_jax flags off — script needs JAX enabled. + - pattern: "interferometer/visualization_jax" + unset: [PYAUTO_DISABLE_JAX, PYAUTO_FAST_PLOTS, PYAUTO_SMALL_DATASETS] # modeling_visualization_jit{,_delaunay,_rectangular} test the JIT-cached # visualization path: needs JAX enabled (Part 1 asserts log_likelihood is # a jax.Array), the full-resolution mask, a real Nautilus run for Part 2, # and savefig active (Part 2 asserts fit.png lands on disk). - pattern: "imaging/modeling_visualization_jit" unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS] + # interferometer/modeling_visualization_jit — same intent as the imaging + # analogue but for the interferometer JIT path. + - pattern: "interferometer/modeling_visualization_jit" + unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS] # multi/visualization_imaging loads the 150x150 lens_sersic dataset from # disk via Imaging.from_fits (which does not honour SMALL_DATASETS) and # builds its mask via al.Mask2D.circular (which DOES cap to 15x15 under diff --git a/scripts/interferometer/modeling_visualization_jit.py b/scripts/interferometer/modeling_visualization_jit.py new file mode 100644 index 0000000..0124eb9 --- /dev/null +++ b/scripts/interferometer/modeling_visualization_jit.py @@ -0,0 +1,271 @@ +""" +End-to-end test: jit-cached visualization during a real Nautilus model-fit. +========================================================================== + +Exercises the full Path A pipeline shipped across PyAutoArray #288, PyAutoLens +#445, and the PyAutoFit change that turns ``Analysis.fit_for_visualization`` +into a lazily-cached ``jax.jit(self.fit_from)``. + +This test runs in two parts: + +Part 1 — **MGE caching probe.** Uses an MGE linear lens (GaussianGradient basis ++ NFWSph mass + ExternalShear) and MGE parametric source model. Calls +``analysis.fit_for_visualization(instance)`` twice and asserts the second call +is much faster than the first (confirming the compiled function is cached on the +analysis instance, not recompiled per visualization). + +Part 2 — **Live Nautilus quick-update with MGE linear profiles.** Runs a real +(short) Nautilus fit with an MGE lens (``GaussianGradient`` basis + ``NFWSph`` +mass) and MGE source — both use linear light profiles whose ``intensity`` is +solved by the inversion. With the ``pytree_token`` fix on +``LightProfileLinear``, the ``linear_light_profile_intensity_dict`` lookup +survives the JAX pytree round-trip and no ``KeyError`` is raised. Asserts that +``fit.png`` files land on disk, proving the JIT-cached fit_for_visualization +fires correctly during the live search callback. + +This script deliberately opts in with +``AnalysisInterferometer(use_jax=True, use_jax_for_visualization=True)``. +Default model-fit scripts elsewhere in the workspace leave both flags at +``False`` and are therefore untouched by this change. +""" + +import shutil +import time +from os import path +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +import autolens as al +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ + +Re-use the ``simple`` interferometer dataset. Auto-simulate if missing. +""" +mask_radius = 3.5 + +real_space_mask = al.Mask2D.circular( + shape_native=(256, 256), + pixel_scales=0.1, + radius=mask_radius, +) + +dataset_path = path.join("dataset", "interferometer", "simple") + +if al.util.dataset.should_simulate(dataset_path): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/interferometer/simulator.py"], + check=True, + ) + +dataset = al.Interferometer.from_fits( + data_path=path.join(dataset_path, "data.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + uv_wavelengths_path=path.join(dataset_path, "uv_wavelengths.fits"), + real_space_mask=real_space_mask, + transformer_class=al.TransformerDFT, +) + +positions = al.Grid2DIrregular( + al.from_json(file_path=path.join(dataset_path, "positions.json")) +) + + +""" +============================================================================ +Part 1 — MGE caching probe +============================================================================ + +Model: MGE linear lens (Basis of GaussianGradient + NFWSph mass + ExternalShear) +and MGE parametric source. Mirrors the linear MGE pattern from the imaging +analogue at ``scripts/imaging/modeling_visualization_jit.py``. +""" +print("\n" + "=" * 72) +print("Part 1: MGE caching probe") +print("=" * 72) + +mass_mge = af.Model(al.mp.NFWSph) + +total_gaussians = 3 +log10_sigma_list = np.linspace(-2, np.log10(mask_radius), total_gaussians) + +centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list = af.Collection( + af.Model(al.lmp_linear.GaussianGradient) for _ in range(total_gaussians) +) +for i, gaussian in enumerate(gaussian_list): + gaussian.centre.centre_0 = centre_0 + gaussian.centre.centre_1 = centre_1 + gaussian.ell_comps = gaussian_list[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list[i] + gaussian.mass_to_light_ratio = 10.0 + gaussian.mass_to_light_gradient = 1.0 + +bulge_mge = af.Model(al.lp_basis.Basis, profile_list=list(gaussian_list)) +shear_mge = af.Model(al.mp.ExternalShear) + +lens_mge = af.Model( + al.Galaxy, redshift=0.5, bulge=bulge_mge, mass=mass_mge, shear=shear_mge +) + +source_bulge_mge = al.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False +) +source_mge = af.Model(al.Galaxy, redshift=1.0, bulge=source_bulge_mge) + +model_mge = af.Collection(galaxies=af.Collection(lens=lens_mge, source=source_mge)) + +register_model(model_mge) + +analysis_mge = al.AnalysisInterferometer( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + use_jax=True, + use_jax_for_visualization=True, +) + +instance_mge = model_mge.instance_from_prior_medians() + +t0 = time.perf_counter() +fit_1 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_1.log_likelihood) +t1 = time.perf_counter() +compile_time = t1 - t0 +print(f"First call (compile + run): {compile_time:.3f}s") +print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}") +assert isinstance( + fit_1.log_likelihood, jnp.ndarray +), f"expected jax.Array, got {type(fit_1.log_likelihood)}" + +t0 = time.perf_counter() +fit_2 = analysis_mge.fit_for_visualization(instance_mge) +jax.block_until_ready(fit_2.log_likelihood) +t1 = time.perf_counter() +cached_time = t1 - t0 +print(f"Second call (cached): {cached_time:.3f}s") +print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x") + +assert cached_time < compile_time * 0.5, ( + f"Cached call ({cached_time:.3f}s) not faster than compile " + f"({compile_time:.3f}s) — JIT cache is not being hit." +) +assert ( + analysis_mge._jitted_fit_from is not None +), "expected _jitted_fit_from to be cached on the analysis instance after first call" +print("PASS: MGE jit-cached fit_for_visualization works and is reused.") + + +""" +============================================================================ +Part 2 — Live Nautilus quick-update with MGE linear light profiles +============================================================================ + +Model: MGE linear lens (Basis of GaussianGradient + NFWSph mass) and MGE +parametric source. Linear light profiles are used, so the +``linear_light_profile_intensity_dict`` lookup is exercised during +visualization. With the ``pytree_token`` fix on ``LightProfileLinear``, +dict lookups survive the JAX pytree round-trip and no ``KeyError`` is raised. + +The live search fires quick-update visualization every +``iterations_per_quick_update`` calls; we verify fit.png lands on disk. +""" +print("\n" + "=" * 72) +print("Part 2: Live Nautilus with MGE linear profiles + jit-visualization") +print("=" * 72) + +mass_mge2 = af.Model(al.mp.NFWSph) + +total_gaussians2 = 3 +log10_sigma_list2 = np.linspace(-2, np.log10(mask_radius), total_gaussians2) + +centre_0_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) +centre_1_2 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1) + +gaussian_list2 = af.Collection( + af.Model(al.lmp_linear.GaussianGradient) for _ in range(total_gaussians2) +) +for i, gaussian in enumerate(gaussian_list2): + gaussian.centre.centre_0 = centre_0_2 + gaussian.centre.centre_1 = centre_1_2 + gaussian.ell_comps = gaussian_list2[0].ell_comps + gaussian.sigma = 10 ** log10_sigma_list2[i] + gaussian.mass_to_light_ratio = 10.0 + gaussian.mass_to_light_gradient = 1.0 + +bulge_mge2 = af.Model(al.lp_basis.Basis, profile_list=list(gaussian_list2)) + +lens_mge2 = af.Model(al.Galaxy, redshift=0.5, bulge=bulge_mge2, mass=mass_mge2) + +source_bulge_mge2 = al.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False +) +source_mge2 = af.Model(al.Galaxy, redshift=1.0, bulge=source_bulge_mge2) + +model_mge2 = af.Collection(galaxies=af.Collection(lens=lens_mge2, source=source_mge2)) + +register_model(model_mge2) + +analysis_mge2 = al.AnalysisInterferometer( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + use_jax=True, + use_jax_for_visualization=True, +) + +output_root = Path("scripts") / "interferometer" / "images" / "modeling_visualization_jit" +if output_root.exists(): + shutil.rmtree(output_root) +output_root.mkdir(parents=True) + +# Also clean the autofit search output. Without this, Nautilus resumes from +# the previous run's cached samples.csv and skips live sampling — so the +# quick-update visualizer never fires, _jitted_fit_from is never set, and +# the assertion below would fail on every rerun. Force a fresh run. +output_search_root = Path("output") / output_root / "mge_linear" +if output_search_root.exists(): + shutil.rmtree(output_search_root) + +search = af.Nautilus( + path_prefix=str(output_root), + name="mge_linear", + n_live=50, + n_like_max=1500, + iterations_per_quick_update=500, + number_of_cores=1, +) + +print("Running Nautilus ...") +result = search.fit(model=model_mge2, analysis=analysis_mge2) + +# The Nautilus output goes to output////image/ +# The quick-update visualizer writes fit.png during each quick update. +produced_pngs = list(output_search_root.rglob("fit.png")) +print(f"fit.png files produced: {len(produced_pngs)}") +for p in produced_pngs: + print(f" {p}") +assert len(produced_pngs) > 0, ( + f"no fit.png produced under {output_search_root} — " + "quick-update visualization did not fire" +) +assert ( + analysis_mge2._jitted_fit_from is not None +), "expected _jitted_fit_from to be cached on the analysis instance during search" + +print( + "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " + "with MGE linear profiles, fit.png written, no KeyError from " + "linear_light_profile_intensity_dict lookup." +) diff --git a/scripts/interferometer/visualization_jax.py b/scripts/interferometer/visualization_jax.py new file mode 100644 index 0000000..2a6f2ad --- /dev/null +++ b/scripts/interferometer/visualization_jax.py @@ -0,0 +1,139 @@ +""" +Visualization JAX Pilot: Interferometer Analysis +================================================= + +Pilot for https://github.com/PyAutoLabs/PyAutoFit/issues/1227. + +Goal +---- +Run ``VisualizerInterferometer.visualize`` with JAX enabled end-to-end, gated +behind ``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443 +the interferometer visualizer dispatches through +``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in +``jax.jit`` (autolens/interferometer/model/visualizer.py:96). To trace across +that boundary the model and fit return type must be JAX pytrees, so this script +enables pytree registration before constructing the model. Parametric MGE +source — simplest case (no PSF convolution; interferometer operates in Fourier +space via DFT, no pixelization, no inversion). + +Scope +----- +- Parametric MGE source only. +- Calls ``VisualizerInterferometer.visualize`` only (not ``visualize_before_fit``). +- Re-uses the ``simple`` dataset from ``jax_likelihood_functions/interferometer``. +- Uses the default plot config (no bespoke config_source override). +""" + +import shutil +from os import path +from pathlib import Path +from types import SimpleNamespace + +import autofit as af +import autolens as al +from autofit.jax.pytrees import enable_pytrees, register_model +from autolens.interferometer.model.visualizer import VisualizerInterferometer + +enable_pytrees() + + +""" +__Dataset__ + +Re-use the ``simple`` interferometer dataset used by +``jax_likelihood_functions/interferometer``. Auto-simulate if missing. +""" +mask_radius = 3.0 + +real_space_mask = al.Mask2D.circular( + shape_native=(256, 256), + pixel_scales=0.1, + radius=mask_radius, +) + +dataset_path = path.join("dataset", "interferometer", "simple") + +if al.util.dataset.should_simulate(dataset_path): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/interferometer/simulator.py"], + check=True, + ) + +dataset = al.Interferometer.from_fits( + data_path=path.join(dataset_path, "data.fits"), + noise_map_path=path.join(dataset_path, "noise_map.fits"), + uv_wavelengths_path=path.join(dataset_path, "uv_wavelengths.fits"), + real_space_mask=real_space_mask, + transformer_class=al.TransformerDFT, +) + +positions = al.Grid2DIrregular( + al.from_json(file_path=path.join(dataset_path, "positions.json")) +) + + +""" +__Model__ + +Lens: Isothermal mass + ExternalShear (matching the interferometer mge.py +pattern; no lens light). Source: MGE parametric bulge. +""" +mass = af.Model(al.mp.Isothermal) +shear = af.Model(al.mp.ExternalShear) +lens = af.Model(al.Galaxy, redshift=0.5, mass=mass, shear=shear) + +source_bulge = al.model_util.mge_model_from( + mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False +) +source = af.Model(al.Galaxy, redshift=1.0, bulge=source_bulge) + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + +register_model(model) + + +""" +__Analysis__ + +``use_jax=True`` turns on the JAX ``_xp`` path; ``use_jax_for_visualization=True`` +tells the search-level visualization path to wrap ``fit_from`` in ``jax.jit`` +via the new ``Analysis.fit_for_visualization`` helper. +""" +analysis = al.AnalysisInterferometer( + dataset=dataset, + positions_likelihood_list=[al.PositionsLH(threshold=0.4, positions=positions)], + use_jax=True, + use_jax_for_visualization=True, + title_prefix="JAX_PILOT", +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "interferometer" / "images" / "visualization_jax" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Run visualize on the eager-JAX fit__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerInterferometer.visualize with use_jax_for_visualization=True ...") +VisualizerInterferometer.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, +) +assert (image_path / "fit.png").exists(), "fit.png was not produced" +print("PILOT SUCCEEDED — JAX-backed interferometer visualization produced fit.png/tracer.png.")