diff --git a/config/build/env_vars.yaml b/config/build/env_vars.yaml index f9729e3..98b0529 100644 --- a/config/build/env_vars.yaml +++ b/config/build/env_vars.yaml @@ -35,8 +35,14 @@ overrides: unset: [PYAUTO_SMALL_DATASETS] - pattern: "imaging/model_fit" unset: [PYAUTO_SMALL_DATASETS] - - pattern: "imaging/visualization" + - pattern: "imaging/visualization.py" unset: [PYAUTO_SMALL_DATASETS] + # visualization_jax exercises the jit-cached fit_for_visualization path + # (registered model + autoarray pytrees). It must run with JAX enabled — + # PYAUTO_DISABLE_JAX=1 would silently flip use_jax flags off and the + # script would no-op. + - pattern: "imaging/visualization_jax" + unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS] - pattern: "jax_grad/imaging_lp" unset: [PYAUTO_SMALL_DATASETS] - pattern: "jax_grad/imaging_mge" diff --git a/scripts/imaging/visualization_jax.py b/scripts/imaging/visualization_jax.py index bebb197..940461a 100644 --- a/scripts/imaging/visualization_jax.py +++ b/scripts/imaging/visualization_jax.py @@ -7,16 +7,13 @@ Goal ---- Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind -the new ``use_jax_for_visualization`` flag on ``Analysis``. The parametric MGE -source is used deliberately (simplest case — no pixelization, no inversion). - -This is **Path C** from the plan: ``fit_from`` runs on the eager JAX path -(``use_jax=True`` makes ``_xp`` be ``jnp``) and returns a ``FitImaging`` -backed by ``jax.Array`` objects. Matplotlib-bound plotters materialise arrays -to NumPy at the boundary. No ``jax.jit`` is applied to ``fit_from`` — the -full-JIT path (Path A) depends on ``FitImaging`` itself becoming a pytree, -which is tracked as a separate task (see -``PyAutoPrompt/issued/fit_imaging_pytree.md``). +``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443 +(2026-04-19) the imaging visualizer dispatches through +``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in +``jax.jit``. To trace across that boundary the model and fit return type +must be JAX pytrees, so this script enables pytree registration before +constructing the model. Parametric MGE source — simplest case (no +pixelization, no inversion). Scope ----- @@ -28,7 +25,6 @@ """ import shutil -import traceback from os import path from pathlib import Path from types import SimpleNamespace @@ -42,8 +38,11 @@ import autofit as af import autolens as al +from autofit.jax.pytrees import enable_pytrees, register_model from autolens.imaging.model.visualizer import VisualizerImaging +enable_pytrees() + """ __Dataset__ @@ -103,6 +102,8 @@ model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) +register_model(model) + """ __Analysis__ @@ -137,19 +138,13 @@ instance = model.instance_from_prior_medians() print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...") -try: - VisualizerImaging.visualize( - analysis=analysis, - paths=paths, - instance=instance, - during_analysis=False, - ) - assert (image_path / "parametric" / "fit.png").exists() or ( - image_path / "fit.png" - ).exists(), "fit.png was not produced" - print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.") -except Exception: - print("PILOT FAILED — traceback below:") - print("=" * 72) - traceback.print_exc() - print("=" * 72) +VisualizerImaging.visualize( + analysis=analysis, + paths=paths, + instance=instance, + during_analysis=False, +) +assert (image_path / "parametric" / "fit.png").exists() or ( + image_path / "fit.png" +).exists(), "fit.png was not produced" +print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.")