feat: autolens interferometer JAX visualization coverage #86

@Jammy2211

Description

Overview

The autolens_workspace_test interferometer dataset has a NumPy visualization.py but no visualization_jax.py or modeling_visualization_jit.py; those scripts exist only for imaging. PyAutoLens's AnalysisInterferometer already dispatches through analysis.fit_for_visualization (visualizer.py:96, 209) and has full pytree registration, so the library-side wiring is in place and this task only adds the two missing test scripts. This is Phase 1A of PyAutoPrompt/issued/jax_visualization.md (the JAX visualization roadmap).

Plan

  • Add scripts/interferometer/visualization_jax.py, mirroring scripts/imaging/visualization_jax.py but using AnalysisInterferometer and the interferometer simulator. It must call enable_pytrees() and register_model(model) from the start (lesson from PR #85, "fix: register_model in visualization_jax.py to actually exercise JIT path").
  • Add scripts/interferometer/modeling_visualization_jit.py — mirrors the imaging analogue's two-part shape (caching probe + live Nautilus run with iterations_per_quick_update, asserting the JIT cache fires and subplot_fit.png lands).
  • Reuse the existing simulator under scripts/jax_likelihood_functions/interferometer/.
  • Update config/build/env_vars.yaml to add imaging/visualization_jax-style overrides for the interferometer scripts (unset PYAUTO_DISABLE_JAX etc.).
  • Verify both scripts run locally with JAX enabled — they should print PILOT SUCCEEDED and produce the expected PNGs.
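The per-script overrides in the plan amount to removing a few variables from the environment before a matching script runs. The sketch below assumes, hypothetically, that each env_vars.yaml entry reduces to a pattern plus a list of variables to unset, and that a pattern matches when it is a substring of the script path; the real build tooling's schema and matching rules may differ, and OVERRIDES and env_for_script are illustrative names only.

```python
# Hypothetical in-memory form of the two new override entries.
# Assumption: the real env_vars.yaml schema may differ from this shape.
OVERRIDES = [
    {
        "pattern": "interferometer/visualization_jax",
        "unset": ["PYAUTO_DISABLE_JAX", "PYAUTO_SMALL_DATASETS", "PYAUTO_FAST_PLOTS"],
    },
    {
        "pattern": "interferometer/modeling_visualization_jit",
        "unset": [
            "PYAUTO_DISABLE_JAX",
            "PYAUTO_SMALL_DATASETS",
            "PYAUTO_TEST_MODE",
            "PYAUTO_FAST_PLOTS",
        ],
    },
]


def env_for_script(script_path, base_env, overrides=OVERRIDES):
    """Return a copy of base_env with every matching override's variables removed."""
    env = dict(base_env)  # copy, so the caller's mapping is never mutated
    # Normalise Windows separators so patterns written with "/" still match.
    normalised = script_path.replace("\\", "/")
    for entry in overrides:
        # Assumption: a pattern matches when it appears anywhere in the path.
        if entry["pattern"] in normalised:
            for name in entry["unset"]:
                env.pop(name, None)  # ignore variables that are already unset
    return env
```

Because env_for_script copies base_env, a caller could pass os.environ and hand the result to subprocess.run(env=...) without mutating the parent process.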
Detailed implementation plan

Affected Repositories

  • autolens_workspace_test (primary, only repo)

Work Classification

Workspace

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_workspace_test | main | README.md (unrelated automated version bump; not in scope for this task) |

Suggested branch: feature/autolens-interferometer-jax-viz
Worktree root: ~/Code/PyAutoLabs-wt/autolens-interferometer-jax-viz/ (created later by /start_workspace)

Implementation Steps

  1. scripts/interferometer/visualization_jax.py: mirror the structure of scripts/imaging/visualization_jax.py as it stands after PR #85 ("fix: register_model in visualization_jax.py to actually exercise JIT path"), with these differences:

    • Use al.AnalysisInterferometer instead of al.AnalysisImaging.
    • Use the interferometer dataset path: dataset/interferometer/jax_test.
    • Auto-simulate via subprocess.run([sys.executable, "scripts/jax_likelihood_functions/interferometer/simulator.py"], check=True) if missing.
    • Build the model with the parametric MGE pattern from scripts/jax_likelihood_functions/interferometer/mge.py (lens + source MGE).
    • Critical (lesson from PR #85): add "from autofit.jax.pytrees import enable_pytrees, register_model", call enable_pytrees() at module level, and call register_model(model) after building the model. Without these, jax.jit(fit_from) cannot trace the ModelInstance.
    • Critical: no try/except wrapper. Call VisualizerInterferometer.visualize directly so any failure surfaces loudly, then assert that subplot_fit.png exists (or fit.png, depending on which file the interferometer plotter produces; check against the imaging analogue's pattern).
    • Reuse config_source/visualize/plots.yaml from the existing interferometer/visualization.py so the visualization output is bounded.
  2. scripts/interferometer/modeling_visualization_jit.py — mirror scripts/imaging/modeling_visualization_jit.py:

    • Part 1: caching probe — call analysis.fit_for_visualization(instance) twice and assert the second call is significantly faster than the first (_jitted_fit_from is cached on the analysis instance).
    • Part 2: a live Nautilus run with iterations_per_quick_update=500, n_like_max=1500 and n_live=50 that asserts fit.png files land under the output search root.
    • Same enable_pytrees() + register_model(model) setup as Part 1.
    • Use MGE linear light profiles (matches the imaging analogue's Part 2) so linear_light_profile_intensity_dict is exercised on the interferometer side.
  3. config/build/env_vars.yaml — add overrides analogous to the autolens imaging entries:

    • pattern: "interferometer/visualization_jax" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_FAST_PLOTS (mirrors the new imaging/visualization_jax entry from PR #85).
    • pattern: "interferometer/modeling_visualization_jit" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS (mirrors the existing imaging/modeling_visualization_jit entry).
  4. Verification — run both with JAX enabled inside the worktree:

    cd $WT_ROOT/autolens_workspace_test
    JAX_ENABLE_X64=True python scripts/interferometer/visualization_jax.py
    JAX_ENABLE_X64=True python scripts/interferometer/modeling_visualization_jit.py

    Both must print PILOT SUCCEEDED (or the jit-cache pass message) and produce the expected PNGs.
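The run-and-check steps above (auto-simulate the dataset if missing, run a script with JAX_ENABLE_X64=True, look for PILOT SUCCEEDED, confirm the PNG landed) can be sketched as a small stdlib-only helper module. This is a minimal sketch: ensure_dataset, run_and_check and find_pngs are hypothetical names, not part of autolens_workspace_test, and the assumption that PNGs can be found by recursive search under an output root is illustrative.

```python
import os
import subprocess
import sys
from pathlib import Path


def ensure_dataset(dataset_dir, simulator_script):
    """Run the simulator once if the dataset folder is missing (mirrors step 1)."""
    dataset_dir = Path(dataset_dir)
    if not dataset_dir.exists():
        # check=True raises CalledProcessError, so a broken simulator fails loudly.
        subprocess.run([sys.executable, str(simulator_script)], check=True)
    return dataset_dir


def run_and_check(script, marker="PILOT SUCCEEDED", cwd=None):
    """Run a script with JAX_ENABLE_X64=True and require `marker` in its stdout."""
    env = dict(os.environ, JAX_ENABLE_X64="True")
    result = subprocess.run(
        [sys.executable, str(script)],
        cwd=cwd,
        env=env,
        capture_output=True,
        text=True,
        check=True,  # a non-zero exit (e.g. an uncaught exception) fails here
    )
    if marker not in result.stdout:
        raise AssertionError(f"{script}: expected {marker!r} in stdout")
    return result.stdout


def find_pngs(output_root, name="subplot_fit.png"):
    """Return every file called `name` under output_root; fail if there are none."""
    matches = sorted(Path(output_root).rglob(name))
    if not matches:
        raise AssertionError(f"no {name} found under {output_root}")
    return matches
```

From the workspace root, a caller might chain ensure_dataset("dataset/interferometer/jax_test", "scripts/jax_likelihood_functions/interferometer/simulator.py") with run_and_check("scripts/interferometer/visualization_jax.py"); no try/except is used, so any failure propagates, matching the "fail loudly" requirement.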

Known risk

autolens_workspace_test interferometer scripts have historically been red on CI because of a gitignored sma.fits fixture (complete.md L1080). The simulator at scripts/jax_likelihood_functions/interferometer/simulator.py may depend on this fixture. If running the simulator from the worktree fails on a missing fixture, mirror the autogalaxy port (complete.md L970) as the fallback: it wrote a self-contained simulator that uses np.random.default_rng(seed=1) to generate 200 synthetic baselines.
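A self-contained stand-in for the sma.fits baselines might look like the following. It is a sketch only: the function name, the uniform-disc sampling and the max_uv scale are assumptions, not what the autogalaxy port actually does; only the seeded np.random.default_rng(seed=1) generator and the count of 200 baselines come from the note above. Feeding the array into the interferometer simulator is left to the real script.

```python
import numpy as np


def synthetic_uv_wavelengths(n_baselines=200, seed=1, max_uv=1.0e5):
    """Draw reproducible (u, v) baseline coordinates in wavelengths.

    Assumption: baselines are sampled uniformly inside a disc of radius
    max_uv; the actual autogalaxy simulator may use a different layout.
    """
    rng = np.random.default_rng(seed=seed)  # fixed seed -> identical baselines every run
    # Taking the sqrt of a uniform draw gives uniform density over the disc area.
    radius = max_uv * np.sqrt(rng.uniform(0.0, 1.0, size=n_baselines))
    angle = rng.uniform(0.0, 2.0 * np.pi, size=n_baselines)
    u = radius * np.cos(angle)
    v = radius * np.sin(angle)
    # Shape (n_baselines, 2): one (u, v) pair per row.
    return np.stack([u, v], axis=1)
```

Seeding the generator keeps the fallback dataset byte-for-byte reproducible, which matters when the JAX scripts compare against a fixed likelihood or cached output.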

Key Files

  • scripts/interferometer/visualization_jax.py (NEW)
  • scripts/interferometer/modeling_visualization_jit.py (NEW)
  • config/build/env_vars.yaml (EDIT — add 2 override entries)

Reference patterns

  • scripts/imaging/visualization_jax.py: the post-PR #85 pattern with enable_pytrees + register_model
  • scripts/imaging/modeling_visualization_jit.py — caching-probe + live-Nautilus pattern
  • scripts/jax_likelihood_functions/interferometer/mge.py — interferometer MGE model setup
  • PyAutoLens autolens/interferometer/model/visualizer.py:96, 209 — already dispatches via fit_for_visualization
  • complete.md L970 — autogalaxy port self-contained simulator pattern (fallback if sma.fits is missing)
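The caching-probe pattern listed above reduces to timing the same call twice. The sketch below is stdlib-only: LazilyCompiledAnalysis is a hypothetical stand-in for an analysis that stores its compiled function on the instance (the _jitted_fit_from attribute named in step 2), compile_fn stands in for jax.jit, and the 2x speedup threshold is an assumed reading of "significantly faster". When timing real JAX calls, block on the result (for example with jax.block_until_ready) before stopping the clock, because JAX dispatches work asynchronously.

```python
import time


class LazilyCompiledAnalysis:
    """Hypothetical stand-in for an analysis that caches a compiled fit function.

    compile_fn plays the role of jax.jit: it is called once, on first use,
    and its result is stored on the instance (like _jitted_fit_from).
    """

    def __init__(self, fit_from, compile_fn):
        self._fit_from = fit_from
        self._compile_fn = compile_fn
        self._jitted_fit_from = None

    def fit_for_visualization(self, instance):
        if self._jitted_fit_from is None:
            # Only the first call pays the compilation cost.
            self._jitted_fit_from = self._compile_fn(self._fit_from)
        return self._jitted_fit_from(instance)


def caching_probe(call, arg, min_speedup=2.0):
    """Time two calls; fail unless the first is min_speedup times slower."""
    start = time.perf_counter()
    call(arg)
    first = time.perf_counter() - start

    start = time.perf_counter()
    call(arg)
    second = time.perf_counter() - start

    if not first >= min_speedup * second:
        raise AssertionError(
            f"cache did not fire: first={first:.4f}s second={second:.4f}s"
        )
    return first, second
```

Asserting on a ratio rather than an absolute time keeps the probe meaningful on both fast laptops and slow CI runners.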

Original Prompt

(elided — full text in PyAutoPrompt/issued/jax_viz_interferometer_coverage.md)
