Skip to content

feat: autolens point_source JAX visualization scripts#91

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/point-source-jax-viz
May 14, 2026
Merged

feat: autolens point_source JAX visualization scripts#91
Jammy2211 merged 1 commit into
mainfrom
feature/point-source-jax-viz

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds the missing point_source/visualization.py, point_source/visualization_jax.py, and point_source/modeling_visualization_jit.py scripts in autolens_workspace_test. Phase 1B of PyAutoPrompt/issued/jax_visualization.md — point-source was the only autolens dataset type with zero visualization coverage of any flavour. Together with the upstream PyAutoLens fix (#506), this closes the autolens imaging + interferometer + point_source coverage for the JIT-cached visualization path.

Closes #90.

Scripts Changed

  • scripts/point_source/visualization.py (NEW) — NumPy baseline. Builds an Isothermal+PointFlux model and calls VisualizerPoint.visualize directly with use_jax=False. Asserts fit.png lands on disk.
  • scripts/point_source/visualization_jax.py (NEW) — JIT-cached path. Adds enable_pytrees() + register_model(model) and constructs al.AnalysisPoint(use_jax=True, use_jax_for_visualization=True, fit_positions_cls=al.FitPositionsImagePairAll, ...). No try/except wrapper. Asserts fit.png lands.
  • scripts/point_source/modeling_visualization_jit.py (NEW) — two-part:
    • Part 1 caching probe (calls analysis.fit_for_visualization(instance) twice, asserts second is significantly faster, asserts _jitted_fit_from is not None).
    • Part 2 live Nautilus run with n_live=50, n_like_max=1500, iterations_per_quick_update=500 and image-plane chi-squared. Cleans both scripts/point_source/images/modeling_visualization_jit/ AND output/scripts/point_source/images/modeling_visualization_jit/point_image_plane/ before the Nautilus call so reruns don't silently resume from cached samples.csv (lesson from PR feat: autolens interferometer JAX visualization scripts #87).
  • config/build/env_vars.yaml — add point_source/visualization_jax (unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS) and point_source/modeling_visualization_jit (unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS) overrides, mirroring the imaging and interferometer analogues.

All three scripts use FitPositionsImagePairAll (image-plane chi-squared) only. Source-plane chi-squared (FitPositionsSource) is still JIT-blocked per scripts/CLAUDE.md L132 — out of scope for this PR; will need its own follow-up when the blocker lifts.

The model deliberately omits a free cosmology parameter — cosmology distance calc caches global state and breaks JIT round-trip (per the existing jax_likelihood_functions/point_source/image_plane.py L144-147 caveat).

Upstream PR

PyAutoLabs/PyAutoLens#506AnalysisPoint.__init__ **kwargs passthrough. Required for these scripts to construct AnalysisPoint(use_jax_for_visualization=True, ...) without TypeError. Library-first merge gate enforces this PR merges only after #506.

Test Plan

  • python scripts/point_source/visualization.pyNumPy point-source visualization produced fit.png.
  • JAX_ENABLE_X64=True python scripts/point_source/visualization_jax.pyPILOT SUCCEEDED — JAX-backed point-source visualization produced fit.png.
  • JAX_ENABLE_X64=True python scripts/point_source/modeling_visualization_jit.py — JIT cache fires during Nautilus, fit.png files produced: 1.

🤖 Generated with Claude Code

Closes #90.

- scripts/point_source/visualization.py (NEW) — NumPy baseline calling
  VisualizerPoint.visualize on the NumPy fall-through path.
- scripts/point_source/visualization_jax.py (NEW) — JIT-cached
  fit_for_visualization via use_jax_for_visualization=True, with
  enable_pytrees() + register_model(model) from the start.
- scripts/point_source/modeling_visualization_jit.py (NEW) — caching
  probe + live Nautilus quick-update run with image-plane chi-squared
  (FitPositionsImagePairAll). Cleans output/<path_prefix>/<name>/
  before launching Nautilus so reruns don't silently skip live sampling.
- config/build/env_vars.yaml — point_source/visualization_jax +
  point_source/modeling_visualization_jit overrides.

Phase 1B of PyAutoPrompt/issued/jax_visualization.md. Image-plane only
(source-plane JIT still blocked per scripts/CLAUDE.md L132).

Depends on PyAutoLabs/PyAutoLens#506.
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 14, 2026
@Jammy2211 Jammy2211 merged commit cdb9ea8 into main May 14, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/point-source-jax-viz branch May 14, 2026 11:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: autolens point_source JAX visualization coverage

1 participant