Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion config/build/env_vars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ overrides:
unset: [PYAUTO_SMALL_DATASETS]
- pattern: "imaging/model_fit"
unset: [PYAUTO_SMALL_DATASETS]
- pattern: "imaging/visualization"
- pattern: "imaging/visualization.py"
unset: [PYAUTO_SMALL_DATASETS]
# visualization_jax exercises the jit-cached fit_for_visualization path
# (registered model + autoarray pytrees). It must run with JAX enabled —
# PYAUTO_DISABLE_JAX=1 would silently flip use_jax flags off and the
# script would no-op.
- pattern: "imaging/visualization_jax"
unset: [PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS]
- pattern: "jax_grad/imaging_lp"
unset: [PYAUTO_SMALL_DATASETS]
- pattern: "jax_grad/imaging_mge"
Expand Down
49 changes: 22 additions & 27 deletions scripts/imaging/visualization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
Goal
----
Run ``VisualizerImaging.visualize`` with JAX enabled end-to-end, gated behind
the new ``use_jax_for_visualization`` flag on ``Analysis``. The parametric MGE
source is used deliberately (simplest case — no pixelization, no inversion).

This is **Path C** from the plan: ``fit_from`` runs on the eager JAX path
(``use_jax=True`` makes ``_xp`` be ``jnp``) and returns a ``FitImaging``
backed by ``jax.Array`` objects. Matplotlib-bound plotters materialise arrays
to NumPy at the boundary. No ``jax.jit`` is applied to ``fit_from`` — the
full-JIT path (Path A) depends on ``FitImaging`` itself becoming a pytree,
which is tracked as a separate task (see
``PyAutoPrompt/issued/fit_imaging_pytree.md``).
``use_jax_for_visualization=True`` on ``Analysis``. After PyAutoLens #443
(2026-04-19) the imaging visualizer dispatches through
``analysis.fit_for_visualization``, which lazily wraps ``fit_from`` in
``jax.jit``. To trace across that boundary the model and fit return type
must be JAX pytrees, so this script enables pytree registration before
constructing the model. Parametric MGE source — simplest case (no
pixelization, no inversion).

Scope
-----
Expand All @@ -28,7 +25,6 @@
"""

import shutil
import traceback
from os import path
from pathlib import Path
from types import SimpleNamespace
Expand All @@ -42,8 +38,11 @@

import autofit as af
import autolens as al
from autofit.jax.pytrees import enable_pytrees, register_model
from autolens.imaging.model.visualizer import VisualizerImaging

enable_pytrees()


"""
__Dataset__
Expand Down Expand Up @@ -103,6 +102,8 @@

model = af.Collection(galaxies=af.Collection(lens=lens, source=source))

register_model(model)


"""
__Analysis__
Expand Down Expand Up @@ -137,19 +138,13 @@
instance = model.instance_from_prior_medians()

print("Running VisualizerImaging.visualize with use_jax_for_visualization=True ...")
try:
VisualizerImaging.visualize(
analysis=analysis,
paths=paths,
instance=instance,
during_analysis=False,
)
assert (image_path / "parametric" / "fit.png").exists() or (
image_path / "fit.png"
).exists(), "fit.png was not produced"
print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.")
except Exception:
print("PILOT FAILED — traceback below:")
print("=" * 72)
traceback.print_exc()
print("=" * 72)
VisualizerImaging.visualize(
analysis=analysis,
paths=paths,
instance=instance,
during_analysis=False,
)
assert (image_path / "parametric" / "fit.png").exists() or (
image_path / "fit.png"
).exists(), "fit.png was not produced"
print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.")
Loading