Skip to content

fix: register_model in visualization_jax.py to actually exercise JIT path#85

Merged
Jammy2211 merged 1 commit intomainfrom
feature/viz-jax-pytree-fix
May 8, 2026
Merged

fix: register_model in visualization_jax.py to actually exercise JIT path#85
Jammy2211 merged 1 commit intomainfrom
feature/viz-jax-pytree-fix

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

scripts/imaging/visualization_jax.py was silently broken under JAX since PyAutoLens #443 (2026-04-19) made VisualizerImaging.visualize dispatch through analysis.fit_for_visualization (which JITs fit_from). Without enable_pytrees() + register_model(model), jax.jit(fit_from) cannot trace the ModelInstance argument and raises TypeError. The script's try/except wrapper caught the failure, printed "PILOT FAILED", and exited 0 — so test runners never noticed.

This PR adds the two missing calls, drops the swallowed-exception pattern so future regressions fail loud, and splits the imaging/visualization env_vars override into a NumPy-only entry and a JAX-only entry that unsets PYAUTO_DISABLE_JAX, mirroring the existing imaging/modeling_visualization_jit override.

Sibling PR in autogalaxy_workspace_test fixes the same latent bug on the autogalaxy side (broken since PyAutoGalaxy #390 on 2026-05-08).

Closes #84.

Scripts Changed

  • scripts/imaging/visualization_jax.py — add enable_pytrees() + register_model(model); drop try/except wrapper; update docstring to reflect the Path-A (jit-traced) intent
  • config/build/env_vars.yaml — split imaging/visualization override; add imaging/visualization_jax entry that unsets PYAUTO_DISABLE_JAX

Test Plan

  • python scripts/imaging/visualization_jax.py (with JAX enabled) — PILOT SUCCEEDED — JAX-backed visualization produced fit.png/tracer.png.

🤖 Generated with Claude Code

Closes #84.

Adds enable_pytrees() + register_model(model) so jax.jit(fit_from) can
trace the ModelInstance argument across the JIT boundary. Without these
the script silently failed with TypeError — the try/except wrapper
printed "PILOT FAILED" and exited 0, so test runners never noticed.

Drops the swallowed-exception pattern so future regressions fail loud.
Splits env_vars.yaml's imaging/visualization override into a NumPy-only
entry (visualization.py) and a JAX-only entry (visualization_jax) that
unsets PYAUTO_DISABLE_JAX, mirroring the existing
imaging/modeling_visualization_jit override.

Sibling fix in autogalaxy_workspace_test ships in parallel.

Verified: PILOT SUCCEEDED with JAX enabled.
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Sibling fix for autogalaxy side: PyAutoLabs/autogalaxy_workspace_test#37

@Jammy2211 Jammy2211 merged commit bdc8ee1 into main May 8, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/viz-jax-pytree-fix branch May 8, 2026 19:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix: register_model in visualization_jax.py to actually exercise JIT path

1 participant