Overview
Both autolens_workspace_test/scripts/imaging/visualization_jax.py and autogalaxy_workspace_test/scripts/imaging/visualization_jax.py are silently broken under JAX: they catch their own JIT failure, print "PILOT FAILED", and exit with code 0, so test runners never notice. Discovered 2026-05-08 while smoke-verifying the autogalaxy dispatch swap (PyAutoGalaxy #390) — the autolens version has been broken since the equivalent dispatch swap in PyAutoLens #443 (2026-04-19). Root cause: neither script calls enable_pytrees() + register_model(model), so jax.jit(fit_from) cannot trace the ModelInstance argument. The working sibling modeling_visualization_jit.py does call them — this PR brings visualization_jax.py to parity in both repos and removes the swallowed-exception pattern so future regressions fail loud.
Plan
- Add an `enable_pytrees()` import + call at the top of each `visualization_jax.py`, mirroring `modeling_visualization_jit.py` lines 43-45.
- Add `register_model(model)` after the model is composed and before the analysis is constructed.
- Drop the `try: ... except Exception: traceback.print_exc()` block and replace it with a direct `VisualizerImaging.visualize(...)` call + assertion. If the JIT path breaks again, the script should fail loudly.
- Update each workspace's `config/build/env_vars.yaml`: split the existing `imaging/visualization` override into a NumPy-only override (matches `imaging/visualization.py`) and a JAX-enabled override (matches `imaging/visualization_jax`) that also unsets `PYAUTO_DISABLE_JAX`. This mirrors the `imaging/modeling_visualization_jit` override, which already does this.
- Verify each script passes when run with JAX enabled (no `PYAUTO_DISABLE_JAX`).
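Why the swallowed exception hides the failure can be shown with a minimal sketch. The two script bodies below are hypothetical stand-ins (a plain `TypeError` plays the role of the JIT tracing error), not the real `visualization_jax.py`; each is run in a child interpreter so its exit code can be read the way a test runner would read it.

```python
import subprocess
import sys

# Hypothetical stand-in for the old pattern: the error is printed, then swallowed.
SWALLOWED = """
import traceback
try:
    raise TypeError("stand-in for the JIT tracing error")
except Exception:
    traceback.print_exc()
    print("PILOT FAILED")
"""

# Hypothetical stand-in for the fixed pattern: no try/except, so the error propagates.
PROPAGATED = """
raise TypeError("stand-in for the JIT tracing error")
"""


def exit_code(source: str) -> int:
    """Run `source` in a fresh Python process and return its exit code."""
    result = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
    )
    return result.returncode


swallowed_code = exit_code(SWALLOWED)
propagated_code = exit_code(PROPAGATED)
print(swallowed_code, propagated_code)  # 0 for the swallowed version, 1 for the fixed one
```

A runner that only checks the exit code therefore records the swallowed version as a pass, even though the traceback and "PILOT FAILED" were printed.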
Detailed implementation plan
Affected Repositories
- autolens_workspace_test (primary, lead PR)
- autogalaxy_workspace_test (sibling fix)
Work Classification
Workspace
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_workspace_test | main | README.md (unrelated automated version bump; not in scope for this task) |
| ./autogalaxy_workspace_test | main | clean |
Suggested branch: feature/viz-jax-pytree-fix
Worktree root: ~/Code/PyAutoLabs-wt/viz-jax-pytree-fix/ (created later by /start_workspace)
Implementation Steps
- `autolens_workspace_test/scripts/imaging/visualization_jax.py`:
  - After `import autolens as al`, add `from autofit.jax.pytrees import enable_pytrees, register_model` and a top-level `enable_pytrees()` call.
  - After the `model = af.Collection(...)` line, add `register_model(model)`.
  - Drop the `try`/`except` wrapper around `VisualizerImaging.visualize(...)` and replace it with the direct call + the existing assertion.
  - Drop the now-unused `import traceback`.
- `autogalaxy_workspace_test/scripts/imaging/visualization_jax.py`:
  - Same edits, with `import autogalaxy as ag` instead of `import autolens as al`.
- `autolens_workspace_test/config/build/env_vars.yaml`:
  - Split the existing `imaging/visualization` override:
    - Keep one entry matching `imaging/visualization.py` (substring match, so it only matches the NumPy script) with the same `unset: [PYAUTO_FAST_PLOTS, PYAUTO_SMALL_DATASETS]`.
    - Add a new entry matching `imaging/visualization_jax` that also unsets `PYAUTO_DISABLE_JAX`.
- `autogalaxy_workspace_test/config/build/env_vars.yaml`:
  - Same split.
- Verify each script passes when run with JAX enabled (no `PYAUTO_DISABLE_JAX`). Both should print `PILOT SUCCEEDED`.
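The split relies on override patterns being matched as substrings of the script path. The sketch below assumes exactly that and nothing more: `OLD_OVERRIDES`, `NEW_OVERRIDES` and `resolve_unset` are hypothetical and do not reproduce the real `env_vars.yaml` schema or the build tooling that reads it.

```python
# Hypothetical in-memory view of the overrides; the real env_vars.yaml layout may differ.
OLD_OVERRIDES = [
    ("imaging/visualization", ["PYAUTO_FAST_PLOTS", "PYAUTO_SMALL_DATASETS"]),
]

NEW_OVERRIDES = [
    ("imaging/visualization.py", ["PYAUTO_FAST_PLOTS", "PYAUTO_SMALL_DATASETS"]),
    (
        "imaging/visualization_jax",
        ["PYAUTO_FAST_PLOTS", "PYAUTO_SMALL_DATASETS", "PYAUTO_DISABLE_JAX"],
    ),
]


def resolve_unset(script_path, overrides):
    """Collect variables to unset from every override whose pattern is a substring of the path."""
    unset = []
    for pattern, names in overrides:
        if pattern in script_path:
            unset.extend(name for name in names if name not in unset)
    return unset


NUMPY_SCRIPT = "scripts/imaging/visualization.py"
JAX_SCRIPT = "scripts/imaging/visualization_jax.py"

for overrides in (OLD_OVERRIDES, NEW_OVERRIDES):
    for script in (NUMPY_SCRIPT, JAX_SCRIPT):
        print(script, resolve_unset(script, overrides))
```

With the old single pattern, both scripts resolve to the same list and `PYAUTO_DISABLE_JAX` is never unset, so the JAX script inherits the default. After the split, only the JAX script unsets it, and neither narrowed pattern matches `modeling_visualization_jit.py`, which keeps its own override.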
Key Files
autolens_workspace_test/scripts/imaging/visualization_jax.py — add register_model + remove swallowed exception
autogalaxy_workspace_test/scripts/imaging/visualization_jax.py — same
autolens_workspace_test/config/build/env_vars.yaml — split imaging/visualization override
autogalaxy_workspace_test/config/build/env_vars.yaml — same
Reference (working sibling pattern)
autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py:43-45 — the working enable_pytrees() + register_model() pattern
autogalaxy_workspace_test/config/build/env_vars.yaml:51-52 — existing imaging/modeling_visualization_jit env override that unsets PYAUTO_DISABLE_JAX (template for the new imaging/visualization_jax entry)
- PyAutoGalaxy #390 (merged 2026-05-08) — the dispatch swap that surfaced this issue on the autogalaxy side
- PyAutoLens #443 (2026-04-19) — earlier autolens dispatch swap that left the autolens script in the same broken state
Original Prompt
Both `autolens_workspace_test/scripts/imaging/visualization_jax.py` and
`autogalaxy_workspace_test/scripts/imaging/visualization_jax.py` are
**silently broken under JAX**: when run with `PYAUTO_DISABLE_JAX` unset
(i.e. JAX actually enabled), they fail with:
```
TypeError: Error interpreting argument to
as an abstract array. The problematic value is of type
<class 'autofit.mapper.model.ModelInstance'> and was passed to the function
at path instance.
```
Both scripts catch their own exception, print "PILOT FAILED", and then exit
with code 0 — so test runners don't notice. Discovered 2026-05-08 while
smoke-verifying the autogalaxy dispatch swap (PyAutoGalaxy #390). The
autolens version has been broken since the equivalent dispatch swap in
PR #443 (2026-04-19).
__Why this matters__
These scripts are meant to be the dedicated smoke verifier for the
`use_jax_for_visualization=True` JIT path on imaging — but in their
current state they silently no-op:
- They are **not in `smoke_tests.txt`** in either workspace_test repo.
- The workspaces' `env_vars.yaml` matches `imaging/visualization` and
applies `PYAUTO_DISABLE_JAX=1` (default), which silently flips both
`use_jax` flags off in `Analysis.__init__`. The script falls through to
NumPy and "passes" without exercising the JIT path it claims to test.
Net effect: the JIT-cached visualization path has zero smoke coverage on
the standalone (non-Nautilus) call site. Coverage of the JIT path in a
live search exists via `imaging/modeling_visualization_jit.py` (which
DOES register the model — see lines 43-45) but the standalone
`fit_for_visualization` path is untested.
__Root cause__
`fit_for_visualization` lazily wraps `fit_from` in `jax.jit` when
`use_jax_for_visualization=True` (see
`@PyAutoFit/autofit/non_linear/analysis/analysis.py:114-122`). For
`jax.jit` to trace the call, the `instance: ModelInstance` argument must
be pytree-registered. `autofit.jax.pytrees` provides `enable_pytrees()`
+ `register_model(model)` for that — but neither `visualization_jax.py`
script calls them.
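The failure mode can be reproduced with plain JAX and no PyAutoFit at all. This is a hedged sketch: `ToyInstance` and `fit_from` below are hypothetical stand-ins, and the direct `jax.tree_util.register_pytree_node` call is only assumed to be analogous to what `enable_pytrees()` + `register_model(model)` arrange for real `ModelInstance` objects; it is not their implementation.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class ToyInstance:
    """Hypothetical stand-in for autofit's ModelInstance: a plain attribute container."""

    def __init__(self, centre, intensity):
        self.centre = centre
        self.intensity = intensity


def fit_from(instance):
    # Toy "fit": a Gaussian profile evaluated at x = 0.
    return instance.intensity * jnp.exp(-instance.centre ** 2)


def jit_call(instance):
    """Return (value, None) on success, or (None, error) if jax.jit cannot trace the argument."""
    try:
        return jax.jit(fit_from)(instance), None
    except TypeError as error:
        return None, error


# 1. Unregistered: jax.jit cannot interpret ToyInstance as an array or a pytree.
_, error_before = jit_call(ToyInstance(centre=0.0, intensity=2.0))

# 2. Register the class, telling JAX how to split it into leaves and rebuild it.
tree_util.register_pytree_node(
    ToyInstance,
    lambda inst: ((inst.centre, inst.intensity), None),  # flatten -> (children, aux_data)
    lambda aux_data, children: ToyInstance(*children),   # unflatten
)

value_after, _ = jit_call(ToyInstance(centre=0.0, intensity=2.0))
print(type(error_before).__name__, float(value_after))
```

Registration has to happen before the jitted call is traced, which matches the plan of calling `enable_pytrees()` at import time and `register_model(model)` before the analysis is built.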
The working sibling `modeling_visualization_jit.py` does:
```python
from autofit.jax.pytrees import enable_pytrees, register_model
enable_pytrees()
...
register_model(model_mge)
```

Both `visualization_jax.py` scripts need the same.
What to change
Apply the same fix to both files:
@autolens_workspace_test/scripts/imaging/visualization_jax.py
@autogalaxy_workspace_test/scripts/imaging/visualization_jax.py
For each:
- Add at the top, after `import autolens as al` (or `import autogalaxy as ag`):

```python
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
```

- After the `model = af.Collection(...)` line, before `analysis = ag.AnalysisImaging(...)`, add `register_model(model)`.
- Replace the script-level try/except that swallows the JIT error with a direct call, so any exception propagates and future regressions are caught instead of being printed while the script silently exits 0:

```python
VisualizerImaging.visualize(
    analysis=analysis,
    paths=paths,
    instance=instance,
    during_analysis=False,
)

assert (image_path / "parametric" / "fit.png").exists() or (
    image_path / "fit.png"
).exists(), "fit.png was not produced"

print("PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.")
```
Drop the try: ... except Exception: traceback.print_exc() block. If
the JIT path breaks again, the script should fail loudly.
env_vars.yaml — second-order question
Both workspaces' config/build/env_vars.yaml have an imaging/visualization
override that applies to both visualization.py and visualization_jax.py
(substring match) and currently sets PYAUTO_DISABLE_JAX=1 for both.
After this fix lands, visualization_jax.py will actually need JAX to
run (it was always supposed to, but silently fell through). Two options:
- Option A — narrow the existing override. Change the pattern to
imaging/visualization.py so it only matches the NumPy script; add a
new override for imaging/visualization_jax that unsets PYAUTO_DISABLE_JAX
alongside the existing unsets. Mirrors the pattern used for
imaging/modeling_visualization_jit.
- Option B — leave env unchanged, mark the script auto-skipping in CI.
Worse — keeps the silent-skip behaviour we are trying to remove.
Take Option A.
Verification
After both files are fixed, run each WITH `PYAUTO_DISABLE_JAX` unset:

```bash
unset PYAUTO_DISABLE_JAX  # JAX must actually be enabled

# autolens
cd autolens_workspace_test
JAX_ENABLE_X64=True NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib \
  python scripts/imaging/visualization_jax.py

# autogalaxy (sibling checkout, so step back out of autolens_workspace_test first)
cd ../autogalaxy_workspace_test
JAX_ENABLE_X64=True NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib \
  python scripts/imaging/visualization_jax.py
```
Both should print PILOT SUCCEEDED ... and produce fit.png (or
parametric/fit.png). Re-run with the env_vars.yaml-resolved env (which
will now include unset: [PYAUTO_DISABLE_JAX, ...] for the
imaging/visualization_jax pattern) to confirm the build path also runs
JAX.
The pre-existing imaging/visualization.py (NumPy) and
imaging/modeling_visualization_jit.py must continue to pass.
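A small harness can check both conditions the verification step cares about: a zero exit code and the `PILOT SUCCEEDED` marker, with `PYAUTO_DISABLE_JAX` removed from the environment. This is a hypothetical helper (`run_pilot` is not part of either repo); the extra variables mirror the commands above, and the usage example runs stand-in one-liners rather than the real scripts.

```python
import os
import subprocess
import sys


def run_pilot(script_args, cwd=None):
    """Run a pilot script with JAX enabled and report whether it really succeeded.

    Hypothetical helper: success requires exit code 0 AND the PILOT SUCCEEDED marker,
    so a script that swallows its own error and exits 0 still counts as a failure.
    """
    env = dict(os.environ)
    env.pop("PYAUTO_DISABLE_JAX", None)  # JAX must actually be enabled
    env.update(
        JAX_ENABLE_X64="True",
        NUMBA_CACHE_DIR="/tmp/numba_cache",
        MPLCONFIGDIR="/tmp/matplotlib",
    )
    result = subprocess.run(
        [sys.executable, *script_args],
        cwd=cwd,
        env=env,
        capture_output=True,
        text=True,
    )
    return result.returncode == 0 and "PILOT SUCCEEDED" in result.stdout


# Stand-ins for scripts/imaging/visualization_jax.py:
ok = run_pilot(
    ["-c", "import os; assert 'PYAUTO_DISABLE_JAX' not in os.environ; print('PILOT SUCCEEDED')"]
)
silent_failure = run_pilot(["-c", "print('PILOT FAILED')"])
print(ok, silent_failure)  # True False
```

Against the real repos the call would look like `run_pilot(["scripts/imaging/visualization_jax.py"], cwd="autolens_workspace_test")`.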
Out of scope
- No library changes. The failure is purely test-script + env config.
- No additions to `smoke_tests.txt`. Per the user's smoke-test policy (MEMORY.md: "Smoke tests are a small curated subset"), promoting these scripts into smoke is a separate decision.
- No `register_model` audit elsewhere. If other workspace_test scripts also try `use_jax_for_visualization=True` without registering, they hit the same issue, but that's a follow-up prompt to author after this one lands and the pattern is established.
Reference
@PyAutoFit/autofit/jax/pytrees.py — enable_pytrees, register_model
@autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py:43-45 — working sibling pattern
@autogalaxy_workspace_test/scripts/imaging/modeling_visualization_jit.py — working sibling
@PyAutoFit/autofit/non_linear/analysis/analysis.py:82-122 — fit_for_visualization JIT dispatch
- PyAutoGalaxy #390 — dispatch swap that surfaced this issue (2026-05-08)
- PyAutoLens #443 — earlier autolens dispatch swap (2026-04-19) that left autolens script in same broken state
PyAutoPrompt/z_features/jax_visualization.md — sequenced roadmap (informal Phase 1 follow-up)