
feat: jax_likelihood_functions/point_source parity with profiling coverage #51

@Jammy2211


Overview

Extend autolens_workspace_test/scripts/jax_likelihood_functions/point_source/ from the single existing point.py test to cover both fit-positions modes profiled in autolens_workspace_developer/jax_profiling/point_source/ — image-plane (FitPositionsImagePairAll) and source-plane (FitPositionsSource) chi-squared. Both variants reuse the canonical simulated PointDataset produced by autolens_workspace_test/scripts/point_source/simulators/point_source.py so we do not fork the simulator.

Plan

  • Add image_plane.py: AnalysisPoint(fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True) + three-step pattern. Image-plane fitting already JITs end-to-end, so this should pass cleanly.
  • Add source_plane.py: AnalysisPoint(fit_positions_cls=al.FitPositionsSource, use_jax=True) + three-step pattern, wrapping the Path A JIT in a try/except TracerArrayConversionError that prints a BLOCKER line pointing at the known Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation bug.
  • Both scripts use the should_simulate bootstrap to invoke the existing simulator — no new simulator file in the jax folder.
  • Both scripts assert the eager NumPy log_likelihood against a hardcoded regression constant (filled in on first run) so the NumPy side is guarded even when the source-plane JIT is blocked.
  • Update autolens_workspace_test/scripts/CLAUDE.md to list the two new scripts in the jax_likelihood_functions/ table.
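The "hardcoded regression constant (filled in on first run)" bullet can be sketched in plain Python as below. This is a minimal illustration, not code from the existing scripts: EXPECTED_LOG_LIKELIHOOD and check_regression are hypothetical names, and the rtol=1e-4 default is borrowed from the round-trip tolerance used elsewhere in this plan.

```python
# Hypothetical placeholder: stays None until the first run prints the real value,
# which is then pasted in by hand.
EXPECTED_LOG_LIKELIHOOD = None


def check_regression(log_likelihood, expected, rtol=1e-4):
    """Guard the eager NumPy log-likelihood with a hardcoded regression constant.

    Returns False (and prints the value) while the constant is still unset,
    True once the value matches, and raises AssertionError on drift.
    """
    value = float(log_likelihood)
    if expected is None:
        print(f"Regression constant not set yet; record this value: {value!r}")
        return False
    # Relative tolerance against the stored constant, i.e. the same check as
    # numpy.testing.assert_allclose(value, expected, rtol=rtol) with atol=0.
    if abs(value - expected) > rtol * abs(expected):
        raise AssertionError(
            f"log_likelihood {value!r} drifted from regression constant {expected!r}"
        )
    return True
```

In a script this would be called as check_regression(fit.log_likelihood, EXPECTED_LOG_LIKELIHOOD) on the eager fit. Printing rather than failing while the constant is unset is a judgment call; a stricter variant could raise until the value is recorded, so an unfilled placeholder can never pass silently.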
Detailed implementation plan

Affected Repositories

  • autolens_workspace_test (primary)

Branch Survey

Repository | Current Branch | Dirty?
./autolens_workspace_test | main | one deleted tracked file (scripts/jax_likelihood_functions/imaging/nnls_invariance.py); unrelated, pre-existing

Suggested branch: feature/jax-likelihood-point-source-parity

Implementation Steps

  1. Scaffolding boilerplate (shared by both new scripts)

    • Copy the should_simulate + subprocess.run([sys.executable, "scripts/point_source/simulators/point_source.py"], check=True) pattern from the existing jax_likelihood_functions/point_source/point.py.
    • Load the dataset via al.from_json using the canonical path from point.py (confirm it against the existing simulator's output path; likely dataset/point_source/simple/point_dict.json or similar).
    • Call enable_pytrees() + register_model(model) before Path A JIT.
  2. scripts/jax_likelihood_functions/point_source/image_plane.py

    • Build solver = al.PointSolver.for_grid(grid=..., pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp).
    • analysis_jit = al.AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True).
    • Three-step pattern: assert fitness._vmap(parameters) against a hardcoded placeholder log-likelihood (filled in on first run) → compare the eager analysis_np.fit_from(instance) with jax.jit(analysis_jit.fit_from)(instance) as a round-trip with rtol=1e-4.
    • Print PASS: jit(fit_from) round-trip matches NumPy scalar.
  3. scripts/jax_likelihood_functions/point_source/source_plane.py

    • Same scaffolding, but fit_positions_cls=al.FitPositionsSource.
    • Wrap Path A in try/except jax.errors.TracerArrayConversionError — on catch, print BLOCKER: source-plane JIT blocked by Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation — tracked in admin_jammy/prompt/issued/fit_point_pytree.md. and set full_pipeline_jits=False.
    • Still assert the eager NumPy fit.log_likelihood against a hardcoded constant for regression coverage.
    • Print PASS: jit(fit_from) round-trip matches NumPy scalar. only if the try succeeds; otherwise print the BLOCKER line.
  4. Update scripts/CLAUDE.md

    • Add two new rows to the jax_likelihood_functions/ table:
      • point_source/image_plane.py | FitPositionsImagePairAll chi-squared (image-plane)
      • point_source/source_plane.py | FitPositionsSource chi-squared (source-plane, JIT blocked)
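Step 2's round-trip comparison reduces to comparing two scalars at rtol=1e-4. The sketch below assumes both fit objects expose a scalar log_likelihood (as the eager fit.log_likelihood in step 3 suggests); assert_round_trip is a hypothetical helper name, not an existing autolens function, and the commented usage lines assume the analysis_np, analysis_jit and instance objects from step 2.

```python
def assert_round_trip(eager_log_likelihood, jit_log_likelihood, rtol=1e-4):
    """Check that the jax.jit(fit_from) log-likelihood matches the eager NumPy scalar."""
    # float() accepts Python floats, NumPy scalars and 0-d JAX arrays alike.
    eager = float(eager_log_likelihood)
    jitted = float(jit_log_likelihood)
    # Same tolerance form as numpy.testing.assert_allclose(jitted, eager, rtol=rtol)
    # with atol=0: the eager NumPy value is the reference.
    if abs(jitted - eager) > rtol * abs(eager):
        raise AssertionError(f"JIT {jitted!r} != eager {eager!r} at rtol={rtol}")
    print("PASS: jit(fit_from) round-trip matches NumPy scalar.")


# Usage sketch (assumes the objects built in step 2):
# import jax
# fit_np = analysis_np.fit_from(instance)
# fit_jit = jax.jit(analysis_jit.fit_from)(instance)
# assert_round_trip(fit_np.log_likelihood, fit_jit.log_likelihood)
```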

Key Files

  • autolens_workspace_test/scripts/jax_likelihood_functions/point_source/image_plane.py — new
  • autolens_workspace_test/scripts/jax_likelihood_functions/point_source/source_plane.py — new
  • autolens_workspace_test/scripts/jax_likelihood_functions/point_source/point.py — reference for dataset-load wiring
  • autolens_workspace_developer/jax_profiling/point_source/image_plane.py — reference for image-plane JIT wiring
  • autolens_workspace_developer/jax_profiling/point_source/source_plane.py — reference for source-plane BLOCKER pattern
  • autolens_workspace_test/scripts/point_source/simulators/point_source.py — canonical simulator reused via should_simulate
  • autolens_workspace_test/scripts/CLAUDE.md — table update

Scope boundary

  • Do not introduce a new simulator in jax_likelihood_functions/point_source/ — reuse the existing one.
  • Do not attempt to work around the source-plane xp-propagation blocker in the test script; the library fix is tracked separately in admin_jammy/prompt/issued/fit_point_pytree.md.
  • Do not modify autolens_workspace_developer/jax_profiling/point_source/ — that is the profiling reference, not the test surface.

Original Prompt


Extend autolens_workspace_test/scripts/jax_likelihood_functions/point_source/ from the single
point.py to cover the two fit-positions modes profiled in
autolens_workspace_developer/jax_profiling/point_source/ — image-plane and source-plane
chi-squared — while reusing the existing simulated dataset in
autolens_workspace_test/scripts/point_source/simulators/point_source.py.

Baseline already shipped

  • autolens_workspace_test/scripts/jax_likelihood_functions/point_source/point.py — current
    single-variant jax_likelihood coverage.
  • autolens_workspace_test/scripts/point_source/simulators/point_source.py — fixed seeded
    PointDataset simulator (noise_seed=1) that is the canonical input for point-source tests.
  • autolens_workspace_developer/jax_profiling/point_source/{image_plane,source_plane}.py:
    the two fit modes, with full three-tier numerical assertions and a documented source-plane
    blocker (Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation bug).

Reference scripts

  • Three-step pattern template: @autolens_workspace_test/scripts/jax_likelihood_functions/imaging/mge.py.
  • Fit-mode templates: @autolens_workspace_developer/jax_profiling/point_source/image_plane.py
    and source_plane.py — use their AnalysisPoint(..., fit_positions_cls=..., use_jax=True)
    wiring, PointSolver.for_grid(..., xp=jnp) and the ray-trace prefix pattern.
  • Dataset source: @autolens_workspace_test/scripts/point_source/simulators/point_source.py.
    Do not clone the simulator into jax_likelihood_functions/point_source/. Load the
    already-simulated dataset via the should_simulate bootstrap (the same pattern as every other
    jax_likelihood_functions script), invoking the simulator in scripts/point_source/simulators/
    via subprocess.run.

Scripts to add in autolens_workspace_test/scripts/jax_likelihood_functions/point_source/

  1. image_plane.py

    • AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True).
    • PointSolver.for_grid(grid=..., pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp).
    • Three-step pattern: vmap assertion → Path A jax.jit(analysis.fit_from) round-trip →
      PASS: jit(fit_from) round-trip matches NumPy scalar.
    • Image-plane fitting is known to JIT end-to-end (the profiling script confirms), so this
      variant should pass cleanly once FitPointDataset pytree registration is in place.
  2. source_plane.py

    • AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsSource, use_jax=True).
    • Three-step pattern, but wrap Path A (jax.jit(analysis.fit_from)) in a
      try/except jax.errors.TracerArrayConversionError: that prints a clear BLOCKER line and
      sets a full_pipeline_jits=False sentinel, mirroring
      jax_profiling/point_source/source_plane.py.
    • Still assert the eager NumPy path against a hardcoded log-likelihood regression
      constant so the NumPy side of the pipeline is guarded.
    • If / when the Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation fix
      lands, the script will JIT cleanly without modification.
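The try/except sentinel described for source_plane.py can be sketched as a small helper. This is an illustration under stated assumptions, not the profiling script's code: run_path_a is a hypothetical name, the JIT callable is assumed to return the scalar log-likelihood, and the exception type(s) to treat as the known blocker are passed in, so the real script would pass jax.errors.TracerArrayConversionError.

```python
def run_path_a(jit_log_likelihood_fn, instance, eager_log_likelihood,
               blocker_errors, reason, rtol=1e-4):
    """Try the Path A JIT and return the full_pipeline_jits sentinel.

    - Known blocker raised while tracing: print a BLOCKER line, return False.
    - JIT runs and matches the eager value: print PASS, return True.
    - JIT runs but disagrees: raise AssertionError (a real regression, not a blocker).
    """
    try:
        # jax.jit traces on the first call, so tracing errors surface here.
        jitted = float(jit_log_likelihood_fn(instance))
    except blocker_errors:
        print(f"BLOCKER: {reason}")
        return False
    eager = float(eager_log_likelihood)
    if abs(jitted - eager) > rtol * abs(eager):
        raise AssertionError(f"JIT {jitted!r} != eager {eager!r} at rtol={rtol}")
    print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
    return True


# Usage sketch (assumes analysis_np, analysis_jit and instance already exist):
# import jax
# fit_jit = jax.jit(analysis_jit.fit_from)
# full_pipeline_jits = run_path_a(
#     lambda inst: fit_jit(inst).log_likelihood,
#     instance,
#     analysis_np.fit_from(instance).log_likelihood,
#     blocker_errors=(jax.errors.TracerArrayConversionError,),
#     reason="source-plane JIT blocked by Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation",
# )
```

Catching only the named exception keeps other failures loud: a JIT that runs but disagrees with the eager value still raises, so the script can start printing PASS once the upstream fix lands without masking new regressions.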

Dataset reuse

Both scripts load from dataset/point_source/simple/point_dataset_positions_only.json (or
whatever filename the autolens_workspace_test simulator writes — confirm by reading the existing
simulator). The should_simulate bootstrap runs
scripts/point_source/simulators/point_source.py if the dataset is missing. Do not write a
new simulator under jax_likelihood_functions/point_source/; reuse the existing one.
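The should_simulate bootstrap amounts to "run the canonical simulator once if its output file is missing". The sketch below is hypothetical: dataset_missing and ensure_dataset are illustrative names rather than the helper point.py actually uses, and the dataset filename in the usage comment is a placeholder until it is confirmed against the simulator. The subprocess runner is injectable purely so the logic can be checked without launching the simulator.

```python
import subprocess
import sys
from pathlib import Path


def dataset_missing(dataset_path):
    """Return True when the simulated dataset has not been written yet."""
    return not Path(dataset_path).exists()


def ensure_dataset(dataset_path, simulator_script, run=subprocess.run):
    """Run the canonical simulator once if its output is absent.

    Returns True if the simulator was invoked, False if the dataset already existed.
    """
    if dataset_missing(dataset_path):
        # sys.executable reuses the current interpreter (and its environment);
        # check=True raises CalledProcessError if the simulator fails, so the
        # test never proceeds against a missing dataset.
        run([sys.executable, str(simulator_script)], check=True)
        return True
    return False


# Usage sketch (paths relative to the autolens_workspace_test root; the dataset
# filename is unconfirmed):
# ensure_dataset(
#     Path("dataset") / "point_source" / "simple" / "point_dict.json",
#     Path("scripts") / "point_source" / "simulators" / "point_source.py",
# )
```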

Pytree registration

Both scripts call enable_pytrees() + register_model(model) before the Path A JIT. FitPointDataset
registration is covered by the issued fit_point_pytree.md prompt — reference it as a dependency.
If that library-side work has not landed yet, Path A will fail with a pytree-registration error;
document and park rather than hand-register in the test scripts.

Deliverables

  1. Two new scripts in autolens_workspace_test/scripts/jax_likelihood_functions/point_source/:
    image_plane.py, source_plane.py.
  2. image_plane.py prints PASS: jit(fit_from) round-trip matches NumPy scalar.
  3. source_plane.py either prints the PASS line (if the upstream xp bug is fixed) or a clear
    BLOCKER line identifying the eager vs JIT divergence — same pattern as the profiling script.
  4. Update autolens_workspace_test/scripts/CLAUDE.md to list the two new scripts in the
    jax_likelihood_functions/ table.

Scope boundary

  • Do not introduce a new simulator — reuse scripts/point_source/simulators/point_source.py.
  • Do not work around the source-plane xp-propagation blocker in the test script. Document it
    and leave the fix for the library task (fit_point_pytree.md).
  • Do not change autolens_workspace_developer/jax_profiling/point_source/ — that is the profiling
    reference, not the test surface.
