
feat: autolens point_source JAX visualization coverage #90

@Jammy2211

Description


Overview

Point-source is the only autolens dataset type with no visualization coverage of any kind in autolens_workspace_test. Phase 1B of the JAX visualization roadmap fills the gap by adding NumPy baseline → JAX → live Nautilus JIT scripts that mirror the imaging analogues. Discovered while planning: al.AnalysisPoint.__init__ has the same **kwargs passthrough gap that PyAutoLens #500 just fixed on AnalysisInterferometer, so this is a Both task: the library fix ships first, and the workspace scripts follow via the library-first merge gate.

Plan

  • Library: 2-line **kwargs passthrough fix in al.AnalysisPoint.__init__ (mirrors PR #500's AnalysisInterferometer fix). This fixes the TypeError: got an unexpected keyword argument 'use_jax_for_visualization' that the JAX scripts would otherwise hit.
  • Workspace sub-step 1: NumPy baseline scripts/point_source/visualization.py. This must land first, since no NumPy regression baseline exists today.
  • Workspace sub-step 2: scripts/point_source/visualization_jax.py, exercising the JIT-cached fit_for_visualization path with image-plane chi-squared only (source-plane is JIT-blocked per scripts/CLAUDE.md L132). Include enable_pytrees() + register_model(model) from the start, with no try/except wrapper (lesson from PR #85, "fix: register_model in visualization_jax.py to actually exercise JIT path").
  • Workspace sub-step 3: scripts/point_source/modeling_visualization_jit.py, a caching probe plus a live Nautilus run. Explicitly rmtree(output/<path_prefix>/<name>/) before Nautilus so reruns don't silently skip live sampling (lesson from PR #87, "feat: autolens interferometer JAX visualization scripts").
  • config/build/env_vars.yaml: add point_source/visualization_jax + point_source/modeling_visualization_jit overrides.
  • Source-plane variant gated behind a feasibility probe; defer to a follow-up if the JIT blocker still bites.
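The rmtree lesson from PR #87 can be sketched as a small guard that runs before the search. This is a minimal sketch, not the workspace code: the helper name clear_search_output is hypothetical, and it assumes the output/<path_prefix>/<name>/ layout named in the plan. A real script would build the path from its own search settings.

```python
import shutil
import tempfile
from pathlib import Path


def clear_search_output(output_root, path_prefix, name):
    """Hypothetical helper: delete output/<path_prefix>/<name>/ before a live search.

    Without this, a rerun finds the previous samples.csv and resumes from it
    instead of actually sampling again.
    """
    target = Path(output_root) / path_prefix / name
    if target.exists():
        shutil.rmtree(target)
    return target


# Usage demo in a throwaway directory standing in for output/.
with tempfile.TemporaryDirectory() as tmp:
    stale = Path(tmp) / "scripts" / "point_source" / "modeling_visualization_jit"
    stale.mkdir(parents=True)
    (stale / "samples.csv").write_text("stale results\n")

    cleared = clear_search_output(tmp, "scripts/point_source", "modeling_visualization_jit")
    existed_after = cleared.exists()
    parent_kept = cleared.parent.exists()
    # A second call on an already-missing folder is a harmless no-op.
    clear_search_output(tmp, "scripts/point_source", "modeling_visualization_jit")
```

Deleting only the named search folder (rather than all of output/) keeps other scripts' cached results intact.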
Detailed implementation plan

Affected Repositories

  • PyAutoLens (library — AnalysisPoint.__init__ **kwargs passthrough)
  • autolens_workspace_test (workspace — three new scripts + env_vars override)

Work Classification

Both (library-first merge gate; workspace PR merges only after the library PR lands)

Branch Survey

  • ./PyAutoLens: current branch main; dirty: CLAUDE.md (unrelated, automated edit)
  • ./autolens_workspace_test: current branch main; dirty: notebooks/* (unrelated, automated /pre_build output)

Suggested branch: feature/point-source-jax-viz
Worktree root: ~/Code/PyAutoLabs-wt/point-source-jax-viz/ (created later by /start_library)

Implementation Steps

Library (PyAutoLens):

  1. autolens/point/model/analysis.py lines 37-46: add **kwargs to the AnalysisPoint.__init__ parameter list.
  2. autolens/point/model/analysis.py line 80: forward **kwargs via super().__init__(cosmology=cosmology, use_jax=use_jax, **kwargs).
  3. Run pytest test_autolens/point/ and confirm it passes (run a broader sweep too if time allows).
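The library fix is the standard **kwargs passthrough pattern. The sketch below uses hypothetical stand-in classes (BaseAnalysis, AnalysisPointBefore, AnalysisPointAfter), not the real PyAutoLens source, and assumes, as the plan implies, that the base class already accepts use_jax_for_visualization.

```python
class BaseAnalysis:
    """Hypothetical stand-in for the shared base class (assumed to accept the flag)."""

    def __init__(self, cosmology=None, use_jax=False, use_jax_for_visualization=False):
        self.cosmology = cosmology
        self.use_jax = use_jax
        self.use_jax_for_visualization = use_jax_for_visualization


class AnalysisPointBefore(BaseAnalysis):
    # No **kwargs: Python rejects any keyword the subclass does not list.
    def __init__(self, dataset, solver, cosmology=None, use_jax=False):
        self.dataset = dataset
        self.solver = solver
        super().__init__(cosmology=cosmology, use_jax=use_jax)


class AnalysisPointAfter(BaseAnalysis):
    # The 2-line fix: accept **kwargs and forward them to the base class.
    def __init__(self, dataset, solver, cosmology=None, use_jax=False, **kwargs):
        self.dataset = dataset
        self.solver = solver
        super().__init__(cosmology=cosmology, use_jax=use_jax, **kwargs)


try:
    AnalysisPointBefore(dataset=None, solver=None, use_jax=True, use_jax_for_visualization=True)
    before_error = None
except TypeError as exc:
    before_error = str(exc)

after = AnalysisPointAfter(dataset=None, solver=None, use_jax=True, use_jax_for_visualization=True)
```

Forwarding **kwargs means future base-class options reach AnalysisPoint without another signature change, which is why the same fix was applied to AnalysisInterferometer.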

Workspace sub-step 1 — NumPy baseline (autolens_workspace_test):

  1. New file scripts/point_source/visualization.py:
    • Reuse scripts/point_source/simulators/point_source.py for the dataset.
    • Build a parametric mass + point-source model.
    • Construct al.AnalysisPoint(dataset=..., solver=..., use_jax=False, ...).
    • Call VisualizerPoint.visualize(...) (verify the class name against autolens/point/model/visualizer.py).
    • Assert the expected PNG/subplot files land on disk.
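The "files land on disk" assertion can be a small helper that searches the image folder recursively. This is a sketch: the helper names are hypothetical, and the file names fit.png and subplot_fit.png are placeholders; the real names and folder layout come from whatever VisualizerPoint writes.

```python
import tempfile
from pathlib import Path


def missing_pngs(image_dir, expected_names):
    """Return the expected PNG file names not found anywhere under image_dir."""
    found = {path.name for path in Path(image_dir).rglob("*.png")}
    return [name for name in expected_names if name not in found]


def assert_pngs_written(image_dir, expected_names):
    missing = missing_pngs(image_dir, expected_names)
    assert not missing, f"visualization did not write {missing} under {image_dir}"


# Usage demo: a fake image folder containing only one of two expected files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "fit").mkdir()
    (Path(tmp) / "fit" / "fit.png").write_bytes(b"")
    demo_missing = missing_pngs(tmp, ["fit.png", "subplot_fit.png"])
    assert_pngs_written(tmp, ["fit.png"])  # passes: fit.png exists in a subfolder
```

Searching recursively with rglob keeps the check independent of the exact subfolder the visualizer chooses.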

Workspace sub-step 2 — JAX viz (image-plane only):

  1. New file scripts/point_source/visualization_jax.py:
    • Import enable_pytrees and register_model from autofit.jax.pytrees, and call enable_pytrees() at module level.
    • Same model and dataset as sub-step 1, but with fit_positions_cls=FitPositionsImagePairAll (or the equivalent image-plane variant used in jax_likelihood_functions/point_source/image_plane.py).
    • register_model(model) after model is built.
    • al.AnalysisPoint(..., use_jax=True, use_jax_for_visualization=True) — depends on the library fix.
    • Call VisualizerPoint.visualize(...) directly. No try/except wrapper.
    • Assert fit.png (or equivalent) lands on disk.
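To make "image-plane chi-squared" concrete, here is a hedged JAX sketch of a jit-compiled chi-squared over image-plane positions. It pairs each observed image with its nearest model image; that pairing rule is an illustrative choice for this sketch, not necessarily what FitPositionsImagePairAll does, and the function name and array shapes are assumptions. Like the planned script, it calls the jitted function directly with no try/except, so a tracing failure surfaces immediately.

```python
import jax
import jax.numpy as jnp


@jax.jit
def image_plane_chi_squared(observed, model, noise_sigma):
    """Illustrative image-plane chi-squared (not the PyAutoLens implementation).

    observed: (N, 2) observed image positions (y, x).
    model:    (M, 2) model-predicted image positions (y, x).
    """
    # Pairwise squared distances between every observed and model image: (N, M).
    diff = observed[:, None, :] - model[None, :, :]
    dist_sq = jnp.sum(diff**2, axis=-1)
    # Pair each observed image with its nearest model image.
    nearest_sq = jnp.min(dist_sq, axis=1)
    return jnp.sum(nearest_sq / noise_sigma**2)


observed = jnp.array([[0.0, 0.0], [1.0, 1.0]])
model = jnp.array([[0.0, 0.1], [1.0, 1.0], [5.0, 5.0]])
chi_squared = image_plane_chi_squared(observed, model, 0.1)
```

Every operation here is a pure array op with static shapes, which is what lets jax.jit trace it; the source-plane path is the one documented as blocked.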

Workspace sub-step 3 — Live Nautilus jit-cached:

  1. New file scripts/point_source/modeling_visualization_jit.py:
    • Part 1, caching probe: build the model + analysis, using linear profiles where the point-source model supports them; call analysis.fit_for_visualization(instance) twice and assert the second call is significantly faster than the first (which includes JIT compilation).
    • Part 2, live Nautilus: same model, with a real (short) Nautilus run using n_live=50, n_like_max=1500, iterations_per_quick_update=500.
    • Explicitly rmtree(output/scripts/point_source/images/modeling_visualization_jit/mge_linear/) before the Nautilus call so reruns don't silently resume from a cached samples.csv (lesson from PR #87, "feat: autolens interferometer JAX visualization scripts").
    • After the search, assert that analysis._jitted_fit_from is not None and that at least one fit.png was produced.
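The Part 1 caching probe reduces to timing two consecutive calls. A minimal sketch follows; caching_probe and make_fake_cached_fit are hypothetical names, the 5x min_speedup threshold is an assumption (the plan only says "significantly faster"), and the fake fit sleeps on its first call to stand in for JIT compilation.

```python
import time


def time_call(fn, *args):
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


def caching_probe(fn, *args, min_speedup=5.0):
    """Call fn twice; report whether the second call is at least min_speedup times faster."""
    first = time_call(fn, *args)
    second = time_call(fn, *args)
    return first, second, first >= min_speedup * second


def make_fake_cached_fit(compile_seconds=0.2):
    """Hypothetical stand-in for analysis.fit_for_visualization: slow first call, fast after."""
    cache = {}

    def fit_for_visualization(instance):
        if "compiled" not in cache:
            time.sleep(compile_seconds)  # stands in for JIT compilation
            cache["compiled"] = True
        return instance

    return fit_for_visualization


first, second, is_cached = caching_probe(make_fake_cached_fit(), {"dummy": 1})
```

If the real call returns JAX arrays, wrap the result in jax.block_until_ready before stopping the clock, because JAX dispatches work asynchronously and an unblocked timer can under-report the first call.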

Workspace env vars:

  1. config/build/env_vars.yaml:
    • Add pattern: "point_source/visualization_jax" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS.
    • Add pattern: "point_source/modeling_visualization_jit" → unset PYAUTO_DISABLE_JAX, PYAUTO_SMALL_DATASETS, PYAUTO_TEST_MODE, PYAUTO_FAST_PLOTS.
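The effect of the two overrides can be sketched as "start from the default build environment, then unset the listed variables for matching scripts". This is an assumption-laden sketch: the env_vars.yaml schema is not shown here, so the default values are illustrative, env_for_script is a hypothetical name, and substring matching on the script path is assumed rather than confirmed.

```python
DEFAULT_ENV = {
    # Illustrative defaults only; the real values live in config/build/env_vars.yaml.
    "PYAUTO_TEST_MODE": "1",
    "PYAUTO_DISABLE_JAX": "1",
    "PYAUTO_SMALL_DATASETS": "1",
    "PYAUTO_FAST_PLOTS": "1",
}

OVERRIDES = [
    ("point_source/visualization_jax", ["PYAUTO_DISABLE_JAX", "PYAUTO_SMALL_DATASETS"]),
    (
        "point_source/modeling_visualization_jit",
        ["PYAUTO_DISABLE_JAX", "PYAUTO_SMALL_DATASETS", "PYAUTO_TEST_MODE", "PYAUTO_FAST_PLOTS"],
    ),
]


def env_for_script(script_path, default_env=DEFAULT_ENV, overrides=OVERRIDES):
    """Return the environment for one script, assuming substring pattern matching."""
    env = dict(default_env)
    for pattern, unset in overrides:
        if pattern in script_path:
            for name in unset:
                env.pop(name, None)
    return env


baseline_env = env_for_script("scripts/point_source/visualization.py")
jax_env = env_for_script("scripts/point_source/visualization_jax.py")
jit_env = env_for_script("scripts/point_source/modeling_visualization_jit.py")
```

If the build does use substring matching, "point_source/visualization_jax" would also match a future visualization_jax_source_plane.py, so that variant would inherit the same overrides without a new entry.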

Source-plane feasibility gate (decide late in implementation):

  1. Probe whether jax.jit(analysis.fit_from) succeeds with FitPositionsSource. If it still raises (per scripts/CLAUDE.md L132), do NOT add a source-plane variant in this task; file a follow-up prompt instead. If it works, add scripts/point_source/visualization_jax_source_plane.py analogously.
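Unlike the visualization scripts, the feasibility probe does want a try/except, because its job is to record whether tracing fails. A minimal sketch with hypothetical stand-in functions in place of analysis.fit_from: one is traceable, and one branches in Python on a traced value, which jax.jit rejects.

```python
import jax
import jax.numpy as jnp


def jit_feasible(fn, *args):
    """Return (True, None) if jax.jit(fn)(*args) runs, else (False, error summary)."""
    try:
        jax.block_until_ready(jax.jit(fn)(*args))
    except Exception as exc:  # broad on purpose: the probe only reports the outcome
        return False, f"{type(exc).__name__}: {exc}"
    return True, None


def traceable(x):
    return jnp.sum(x**2)


def python_branch(x):
    # A Python `if` on a traced value cannot be evaluated during tracing.
    if jnp.sum(x) > 0:
        return x
    return -x


ok_simple, _ = jit_feasible(traceable, jnp.ones(3))
ok_branch, branch_error = jit_feasible(python_branch, jnp.ones(3))
```

In the real probe, the error summary is what would go into the follow-up prompt if the source-plane blocker still bites.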

Key Files

  • PyAutoLens/autolens/point/model/analysis.py — 2-line **kwargs fix
  • autolens_workspace_test/scripts/point_source/visualization.py (NEW)
  • autolens_workspace_test/scripts/point_source/visualization_jax.py (NEW)
  • autolens_workspace_test/scripts/point_source/modeling_visualization_jit.py (NEW)
  • autolens_workspace_test/config/build/env_vars.yaml (EDIT — 2 new override entries)

Reference patterns

Original Prompt


(elided — full text in PyAutoPrompt/issued/jax_viz_point_source_coverage.md)
