From 74b6809c2a7731a749413ab847b31eee5116040d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 14 May 2026 12:27:23 +0100 Subject: [PATCH] feat: add point_source JAX visualization scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #90. - scripts/point_source/visualization.py (NEW) — NumPy baseline calling VisualizerPoint.visualize on the NumPy fall-through path. - scripts/point_source/visualization_jax.py (NEW) — JIT-cached fit_for_visualization via use_jax_for_visualization=True, with enable_pytrees() + register_model(model) from the start. - scripts/point_source/modeling_visualization_jit.py (NEW) — caching probe + live Nautilus quick-update run with image-plane chi-squared (FitPositionsImagePairAll). Cleans output/// before launching Nautilus so reruns don't silently skip live sampling. - config/build/env_vars.yaml — point_source/visualization_jax + point_source/modeling_visualization_jit overrides. Phase 1B of PyAutoPrompt/issued/jax_visualization.md. Image-plane only (source-plane JIT still blocked per scripts/CLAUDE.md L132). Depends on PyAutoLabs/PyAutoLens#506. --- config/build/env_vars.yaml | 7 + .../modeling_visualization_jit.py | 227 ++++++++++++++++++ scripts/point_source/visualization.py | 122 ++++++++++ scripts/point_source/visualization_jax.py | 139 +++++++++++ 4 files changed, 495 insertions(+) create mode 100644 scripts/point_source/modeling_visualization_jit.py create mode 100644 scripts/point_source/visualization.py create mode 100644 scripts/point_source/visualization_jax.py diff --git a/config/build/env_vars.yaml b/config/build/env_vars.yaml index b2faf632..da5efe85 100644 --- a/config/build/env_vars.yaml +++ b/config/build/env_vars.yaml @@ -80,3 +80,10 @@ overrides: # full resolution; just unset the cap. - pattern: "multi/visualization_imaging" unset: [PYAUTO_SMALL_DATASETS] + # point_source/visualization_jax exercises the jit-cached + # fit_for_visualization path on the point-source side. + - pattern: "point_source/visualization_jax" + unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS] + # point_source/modeling_visualization_jit — live Nautilus + JIT path. + - pattern: "point_source/modeling_visualization_jit" + unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS] diff --git a/scripts/point_source/modeling_visualization_jit.py b/scripts/point_source/modeling_visualization_jit.py new file mode 100644 index 00000000..641757e1 --- /dev/null +++ b/scripts/point_source/modeling_visualization_jit.py @@ -0,0 +1,227 @@ +""" +End-to-end test: jit-cached visualization during a real Nautilus model-fit. +============================================================================ + +Exercises the full JAX visualization pipeline for the point-source analysis +path: ``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)`` with +an ``Isothermal`` lens mass and ``PointFlux`` source (image-plane chi-squared +via ``FitPositionsImagePairAll``). + +This test runs in two parts: + +Part 1 — **Caching probe.** Calls ``analysis.fit_for_visualization(instance)`` +twice and asserts the second call is much faster than the first (confirming +the compiled function is cached on the analysis instance, not recompiled per +visualization call). Also asserts ``analysis._jitted_fit_from is not None`` +after the first call. + +Part 2 — **Live Nautilus quick-update.** Runs a real (short) Nautilus fit. +The live search fires quick-update visualization every +``iterations_per_quick_update`` likelihood evaluations; we verify that +``fit.png`` lands on disk under the Nautilus output tree, proving the +JIT-cached ``fit_for_visualization`` fires correctly during the live +search callback. + +This script deliberately opts in with +``AnalysisPoint(use_jax=True, use_jax_for_visualization=True)``. +Default model-fit scripts elsewhere in the workspace leave both flags at +``False`` and are therefore untouched. +""" + +import shutil +import time +from os import path +from pathlib import Path + +import jax +import jax.numpy as jnp + +import autofit as af +import autolens as al +from autofit.jax.pytrees import enable_pytrees, register_model + +enable_pytrees() + + +""" +__Dataset__ +""" +dataset_path = Path("dataset") / "point_source" / "simple" + +if al.util.dataset.should_simulate(str(dataset_path)): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"], + check=True, + ) + +dataset = al.from_json( + file_path=dataset_path / "point_dataset_positions_only.json", +) + + +""" +__Point Solver__ +""" +grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.2) + +solver = al.PointSolver.for_grid( + grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1 +) + + +""" +============================================================================ +Part 1 — Caching probe +============================================================================ + +Model: Isothermal lens mass + PointFlux source. Same tight priors as the +other point-source scripts so the prior-median instance produces multiple +images. No free cosmology (breaks JIT via global-state distance caching). +""" +print("\n" + "=" * 72) +print("Part 1: Point-source caching probe") +print("=" * 72) + +mass = af.Model(al.mp.Isothermal) +mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8) + +lens = af.Model(al.Galaxy, redshift=0.5, mass=mass) + +point_0 = af.Model(al.ps.PointFlux) +point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) +point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) + +source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0) + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + +register_model(model) + +analysis = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=True, + use_jax_for_visualization=True, +) + +instance = model.instance_from_prior_medians() + +t0 = time.perf_counter() +fit_1 = analysis.fit_for_visualization(instance) +jax.block_until_ready(fit_1.log_likelihood) +t1 = time.perf_counter() +compile_time = t1 - t0 +print(f"First call (compile + run): {compile_time:.3f}s") +print(f" log_likelihood leaf type: {type(fit_1.log_likelihood).__name__}") +assert isinstance( + fit_1.log_likelihood, jnp.ndarray +), f"expected jax.Array, got {type(fit_1.log_likelihood)}" + +t0 = time.perf_counter() +fit_2 = analysis.fit_for_visualization(instance) +jax.block_until_ready(fit_2.log_likelihood) +t1 = time.perf_counter() +cached_time = t1 - t0 +print(f"Second call (cached): {cached_time:.3f}s") +print(f"Speedup: {compile_time / max(cached_time, 1e-9):.1f}x") + +assert cached_time < compile_time * 0.5, ( + f"Cached call ({cached_time:.3f}s) not faster than compile " + f"({compile_time:.3f}s) — JIT cache is not being hit." +) +assert ( + analysis._jitted_fit_from is not None +), "expected _jitted_fit_from to be cached on the analysis instance after first call" +print("PASS: Point-source jit-cached fit_for_visualization works and is reused.") + + +""" +============================================================================ +Part 2 — Live Nautilus quick-update +============================================================================ + +Rebuild the model fresh (register_model on the new instance), create a +separate analysis object, and run a short Nautilus fit. The search fires +quick-update visualization every ``iterations_per_quick_update`` calls; +we assert that ``fit.png`` lands on disk under the Nautilus output tree. +""" +print("\n" + "=" * 72) +print("Part 2: Live Nautilus + jit-visualization for point source") +print("=" * 72) + +mass2 = af.Model(al.mp.Isothermal) +mass2.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass2.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass2.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass2.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass2.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8) + +lens2 = af.Model(al.Galaxy, redshift=0.5, mass=mass2) + +point_02 = af.Model(al.ps.PointFlux) +point_02.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) +point_02.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) + +source2 = af.Model(al.Galaxy, redshift=1.0, point_0=point_02) + +model2 = af.Collection(galaxies=af.Collection(lens=lens2, source=source2)) + +register_model(model2) + +analysis_run = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=True, + use_jax_for_visualization=True, +) + +output_root = Path("scripts") / "point_source" / "images" / "modeling_visualization_jit" +if output_root.exists(): + shutil.rmtree(output_root) +output_root.mkdir(parents=True) + +# Also clean the autofit search output so Nautilus performs live sampling +# instead of resuming from a cached samples.csv — without this the +# quick-update visualizer never fires on reruns. +output_search_root = Path("output") / output_root / "point_image_plane" +if output_search_root.exists(): + shutil.rmtree(output_search_root) + +search = af.Nautilus( + path_prefix=str(output_root), + name="point_image_plane", + n_live=50, + n_like_max=1500, + iterations_per_quick_update=500, + number_of_cores=1, +) + +print("Running Nautilus ...") +result = search.fit(model=model2, analysis=analysis_run) + +# Nautilus writes quick-update images to output////image/ +produced_pngs = list(output_search_root.rglob("fit.png")) +print(f"fit.png files produced: {len(produced_pngs)}") +for p in produced_pngs: + print(f" {p}") +assert len(produced_pngs) > 0, ( + f"no fit.png produced under {output_search_root} — " + "quick-update visualization did not fire" +) +assert ( + analysis_run._jitted_fit_from is not None +), "expected _jitted_fit_from to be cached on the analysis instance during search" + +print( + "\nPASS: jit-cached fit_for_visualization fires during Nautilus quick updates " + f"for point source, fit.png written." +) diff --git a/scripts/point_source/visualization.py b/scripts/point_source/visualization.py new file mode 100644 index 00000000..ed35733c --- /dev/null +++ b/scripts/point_source/visualization.py @@ -0,0 +1,122 @@ +""" +Visualization NumPy: Point Source Analysis +========================================== + +Tests that ``VisualizerPoint.visualize`` runs end-to-end on a +``PointDataset`` using the NumPy (non-JAX) code path and that ``fit.png`` +lands on disk. + +Uses the ``simple/point_dataset_positions_only.json`` dataset (auto-simulated +if missing) with an ``Isothermal`` lens mass and ``PointFlux`` source — the +same model that is proven to JIT end-to-end in +``scripts/jax_likelihood_functions/point_source/image_plane.py``. + +No ``try/except`` — any failure in the visualizer surfaces immediately. +""" + +import shutil +from pathlib import Path +from types import SimpleNamespace + +import autofit as af +import autolens as al +from autolens.point.model.visualizer import VisualizerPoint + + +""" +__Dataset__ +""" +dataset_path = Path("dataset") / "point_source" / "simple" + +if al.util.dataset.should_simulate(str(dataset_path)): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"], + check=True, + ) + +dataset = al.from_json( + file_path=dataset_path / "point_dataset_positions_only.json", +) + + +""" +__Point Solver__ +""" +grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.2) + +solver = al.PointSolver.for_grid( + grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1 +) + + +""" +__Model__ + +Tight priors centred on the true values so the prior-median instance +produces a sensible lens configuration (multiple images exist). +No free cosmology — cosmology distance caching breaks JIT. +""" +mass = af.Model(al.mp.Isothermal) +mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8) + +lens = af.Model(al.Galaxy, redshift=0.5, mass=mass) + +point_0 = af.Model(al.ps.PointFlux) +point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) +point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) + +source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0) + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + +""" +__Analysis__ + +Explicit NumPy path (use_jax=False). +""" +analysis = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=False, +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "point_source" / "images" / "visualization" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Visualize__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerPoint.visualize (NumPy) ...") +VisualizerPoint.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, +) + +print("Files in image_path:", list(image_path.iterdir())) +assert (image_path / "fit.png").exists(), ( + f"fit.png was not produced. Files present: {list(image_path.iterdir())}" +) +print("NumPy point-source visualization produced fit.png.") diff --git a/scripts/point_source/visualization_jax.py b/scripts/point_source/visualization_jax.py new file mode 100644 index 00000000..90eb5b73 --- /dev/null +++ b/scripts/point_source/visualization_jax.py @@ -0,0 +1,139 @@ +""" +Visualization JAX Pilot: Point Source Analysis +=============================================== + +Pilot for the JAX-backed visualization path on ``PointDataset``. + +Goal +---- +Run ``VisualizerPoint.visualize`` with ``use_jax=True`` and +``use_jax_for_visualization=True`` on ``AnalysisPoint``. The point +visualizer dispatches through ``analysis.fit_for_visualization``, which +lazily wraps ``fit_from`` in ``jax.jit``. To trace across that boundary the +model and fit return type must be JAX pytrees, so this script enables pytree +registration before constructing the model. + +Scope +----- +- ``Isothermal`` lens mass + ``PointFlux`` source (image-plane chi-squared). +- Calls ``VisualizerPoint.visualize`` only (not ``visualize_before_fit``). +- Re-uses the ``simple/point_dataset_positions_only.json`` dataset. +- No ``try/except`` wrapper — failure surfaces immediately. +""" + +import shutil +from pathlib import Path +from types import SimpleNamespace + +import autofit as af +import autolens as al +from autofit.jax.pytrees import enable_pytrees, register_model +from autolens.point.model.visualizer import VisualizerPoint + +enable_pytrees() + + +""" +__Dataset__ +""" +dataset_path = Path("dataset") / "point_source" / "simple" + +if al.util.dataset.should_simulate(str(dataset_path)): + import subprocess + import sys + + subprocess.run( + [sys.executable, "scripts/jax_likelihood_functions/point_source/simulator.py"], + check=True, + ) + +dataset = al.from_json( + file_path=dataset_path / "point_dataset_positions_only.json", +) + + +""" +__Point Solver__ +""" +grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.2) + +solver = al.PointSolver.for_grid( + grid=grid, pixel_scale_precision=0.001, magnification_threshold=0.1 +) + + +""" +__Model__ + +Tight priors centred on the true values so the prior-median instance +produces a sensible lens configuration (multiple images exist). +No free cosmology — cosmology distance caching breaks JIT. +""" +mass = af.Model(al.mp.Isothermal) +mass.centre.centre_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.centre.centre_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_0 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.ell_comps.ell_comps_1 = af.UniformPrior(lower_limit=0.0, upper_limit=0.02) +mass.einstein_radius = af.UniformPrior(lower_limit=1.5, upper_limit=1.8) + +lens = af.Model(al.Galaxy, redshift=0.5, mass=mass) + +point_0 = af.Model(al.ps.PointFlux) +point_0.centre.centre_0 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) +point_0.centre.centre_1 = af.UniformPrior(lower_limit=0.06, upper_limit=0.08) + +source = af.Model(al.Galaxy, redshift=1.0, point_0=point_0) + +model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + +register_model(model) + + +""" +__Analysis__ + +``use_jax=True`` turns on the JAX ``_xp`` path; +``use_jax_for_visualization=True`` tells the visualization path to wrap +``fit_from`` in ``jax.jit`` via ``Analysis.fit_for_visualization``. +``title_prefix`` is passed through via PR #506's **kwargs fix. +""" +analysis = al.AnalysisPoint( + dataset=dataset, + solver=solver, + fit_positions_cls=al.FitPositionsImagePairAll, + use_jax=True, + use_jax_for_visualization=True, + title_prefix="JAX_PILOT", +) + + +""" +__Paths__ +""" +image_path = Path("scripts") / "point_source" / "images" / "visualization_jax" +if image_path.exists(): + shutil.rmtree(image_path) +image_path.mkdir(parents=True) +output_path = image_path / "output" +output_path.mkdir(parents=True) +paths = SimpleNamespace(image_path=image_path, output_path=output_path) + + +""" +__Run visualize on the JAX-backed fit__ +""" +instance = model.instance_from_prior_medians() + +print("Running VisualizerPoint.visualize with use_jax_for_visualization=True ...") +VisualizerPoint.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, +) + +print("Files in image_path:", list(image_path.iterdir())) +assert (image_path / "fit.png").exists(), ( + f"fit.png was not produced. Files present: {list(image_path.iterdir())}" +) +print("PILOT SUCCEEDED — JAX-backed point-source visualization produced fit.png.")