Overview
Extend autolens_workspace_test/scripts/jax_likelihood_functions/point_source/ from the single existing point.py test to cover both fit-positions modes profiled in autolens_workspace_developer/jax_profiling/point_source/ — image-plane (FitPositionsImagePairAll) and source-plane (FitPositionsSource) chi-squared. Both variants reuse the canonical simulated PointDataset produced by autolens_workspace_test/scripts/point_source/simulators/point_source.py so we do not fork the simulator.
Plan
- Add image_plane.py — AnalysisPoint(fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True) + three-step pattern. Image-plane fitting JITs end-to-end today, so this should pass cleanly.
- Add source_plane.py — AnalysisPoint(fit_positions_cls=al.FitPositionsSource, use_jax=True) + three-step pattern, wrapping the Path A JIT in a try/except TracerArrayConversionError that prints a BLOCKER line pointing at the known Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation bug.
- Both scripts use the should_simulate bootstrap to invoke the existing simulator — no new simulator file in the jax folder.
- Both scripts assert the eager NumPy log_likelihood against a hardcoded regression constant (filled in on first run) so the NumPy side is guarded even when the source-plane JIT is blocked.
- Update autolens_workspace_test/scripts/CLAUDE.md to list the two new scripts in the jax_likelihood_functions/ table.
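The shared scaffolding in the plan (the should_simulate bootstrap and a regression constant filled in on first run) can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the code in point.py: should_simulate_dataset, bootstrap_dataset, check_regression, the None placeholder convention and the rtol default are hypothetical; only the simulator path and the subprocess.run([sys.executable, ...], check=True) call come from the plan.

```python
import math
import subprocess
import sys
from pathlib import Path

# Simulator path taken from the plan; the helper names below are hypothetical.
SIMULATOR = "scripts/point_source/simulators/point_source.py"

# Regression placeholder: left as None until the first run prints the value.
EXPECTED_LOG_LIKELIHOOD = None


def should_simulate_dataset(dataset_path):
    """Simulate only when the dataset file does not exist yet."""
    return not Path(dataset_path).exists()


def bootstrap_dataset(dataset_path, runner=subprocess.run):
    """Run the canonical simulator once instead of forking it into the jax folder."""
    if should_simulate_dataset(dataset_path):
        runner([sys.executable, SIMULATOR], check=True)
        return True
    return False


def check_regression(log_likelihood, expected=EXPECTED_LOG_LIKELIHOOD, rtol=1e-4):
    """Print the value on the first run; afterwards fail if it drifts."""
    if expected is None:
        print(f"Fill in EXPECTED_LOG_LIKELIHOOD = {log_likelihood!r}")
        return False
    assert math.isclose(log_likelihood, expected, rel_tol=rtol), (
        f"log_likelihood {log_likelihood} drifted from regression value {expected}"
    )
    return True
```

Injecting runner keeps the sketch testable without launching the simulator; a real script would rely on the default subprocess.run.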
Detailed implementation plan
Affected Repositories
- autolens_workspace_test (primary)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_workspace_test | main | one deleted tracked file (scripts/jax_likelihood_functions/imaging/nnls_invariance.py) — unrelated, pre-existing |
Suggested branch: feature/jax-likelihood-point-source-parity
Implementation Steps
- Scaffolding boilerplate (shared by both new scripts)
  - Copy the should_simulate + subprocess.run([sys.executable, "scripts/point_source/simulators/point_source.py"], check=True) pattern from the existing jax_likelihood_functions/point_source/point.py.
  - Load the dataset via al.from_json, using the canonical path from point.py (confirm it against the existing simulator's output path; likely dataset/point_source/simple/point_dict.json or similar).
  - Call enable_pytrees() + register_model(model) before the Path A JIT.
- scripts/jax_likelihood_functions/point_source/image_plane.py
  - Build solver = al.PointSolver.for_grid(grid=..., pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp) and analysis_jit = al.AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True).
  - Three-step pattern: fitness._vmap(parameters) checked against a hardcoded placeholder log-likelihood (filled in on first run) → analysis_np.fit_from(instance) → jax.jit(analysis_jit.fit_from)(instance) round-trip with rtol=1e-4.
  - Print "PASS: jit(fit_from) round-trip matches NumPy scalar."
- scripts/jax_likelihood_functions/point_source/source_plane.py
  - Same scaffolding, but with fit_positions_cls=al.FitPositionsSource.
  - Wrap Path A in try/except jax.errors.TracerArrayConversionError. On catch, print "BLOCKER: source-plane JIT blocked by Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation — tracked in admin_jammy/prompt/issued/fit_point_pytree.md." and set full_pipeline_jits=False.
  - Still assert the eager NumPy fit.log_likelihood against a hardcoded constant for regression coverage.
  - Print "PASS: jit(fit_from) round-trip matches NumPy scalar." only if the try succeeds; otherwise print the BLOCKER line.
- Update scripts/CLAUDE.md
  - Add two new rows under the jax_likelihood_functions/ table:
    - point_source/image_plane.py | FitPositionsImagePairAll chi-squared (image-plane)
    - point_source/source_plane.py | FitPositionsSource chi-squared (source-plane, JIT blocked)
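The Path A logic shared by both new scripts can be illustrated with a toy chi-squared in pure JAX. This is a hedged sketch, not the AnalysisPoint pipeline: toy_log_likelihood (with an xp switch between NumPy and jax.numpy) stands in for fit_from(instance).log_likelihood, and leaky_log_likelihood is a hypothetical stand-in for a code path that drops xp and calls NumPy on a traced value, which is the class of bug that makes jax.jit raise jax.errors.TracerArrayConversionError. The data values, helper names and BLOCKER wording are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical toy data standing in for observed positions and their noise map.
DATA = np.array([1.0, -0.5, 0.25])
NOISE = np.array([0.1, 0.1, 0.2])


def toy_log_likelihood(shift, xp=np):
    """Chi-squared log-likelihood; xp selects NumPy (eager) or jax.numpy (JIT)."""
    residual = (xp.asarray(DATA) - shift) / xp.asarray(NOISE)
    return -0.5 * xp.sum(residual**2)


def jax_log_likelihood(shift):
    """The same toy likelihood routed through jax.numpy."""
    return toy_log_likelihood(shift, xp=jnp)


def leaky_log_likelihood(shift):
    """Hypothetical buggy path: it ignores xp and converts its input with NumPy."""
    residual = (DATA - np.asarray(shift)) / NOISE  # fails on a JAX tracer
    return -0.5 * np.sum(residual**2)


def path_a(jit_fn, shift, value_np, rtol=1e-4):
    """JIT round-trip guarded like source_plane.py; returns a full_pipeline_jits flag."""
    try:
        value_jit = jax.jit(jit_fn)(shift)
    except jax.errors.TracerArrayConversionError:
        print("BLOCKER: JIT blocked because a traced value was converted to NumPy.")
        return False
    np.testing.assert_allclose(float(value_jit), value_np, rtol=rtol)
    print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
    return True


shift = 0.05

# Step 1: vmap over a small batch of parameters.
batched = jax.vmap(jax_log_likelihood)(jnp.array([shift, 0.0]))

# Step 2: eager NumPy value (the real scripts compare this with a hardcoded constant).
value_np = toy_log_likelihood(shift)
np.testing.assert_allclose(float(batched[0]), value_np, rtol=1e-4)

# Step 3: Path A, once for a JIT-clean function and once for the leaky one.
image_plane_jits = path_a(jax_log_likelihood, shift, value_np)
full_pipeline_jits = path_a(leaky_log_likelihood, shift, leaky_log_likelihood(shift))
```

Returning a boolean from path_a plays the role of the full_pipeline_jits sentinel: the eager NumPy value is computed before Path A runs, so the regression check still happens while the JIT path is blocked.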
Key Files
- autolens_workspace_test/scripts/jax_likelihood_functions/point_source/image_plane.py — new
- autolens_workspace_test/scripts/jax_likelihood_functions/point_source/source_plane.py — new
- autolens_workspace_test/scripts/jax_likelihood_functions/point_source/point.py — reference for dataset-load wiring
- autolens_workspace_developer/jax_profiling/point_source/image_plane.py — reference for image-plane JIT wiring
- autolens_workspace_developer/jax_profiling/point_source/source_plane.py — reference for source-plane BLOCKER pattern
- autolens_workspace_test/scripts/point_source/simulators/point_source.py — canonical simulator reused via should_simulate
- autolens_workspace_test/scripts/CLAUDE.md — table update
Scope boundary
- Do not introduce a new simulator in jax_likelihood_functions/point_source/ — reuse the existing one.
- Do not attempt to work around the source-plane xp-propagation blocker in the test script; the library fix is tracked separately in admin_jammy/prompt/issued/fit_point_pytree.md.
- Do not modify autolens_workspace_developer/jax_profiling/point_source/ — that is the profiling reference, not the test surface.
Original Prompt
Extend autolens_workspace_test/scripts/jax_likelihood_functions/point_source/ from the single
point.py to cover the two fit-positions modes profiled in
autolens_workspace_developer/jax_profiling/point_source/ — image-plane and source-plane
chi-squared — while reusing the existing simulated dataset in
autolens_workspace_test/scripts/point_source/simulators/point_source.py.
Baseline already shipped
- autolens_workspace_test/scripts/jax_likelihood_functions/point_source/point.py — current single-variant jax_likelihood coverage.
- autolens_workspace_test/scripts/point_source/simulators/point_source.py — fixed-seed PointDataset simulator (noise_seed=1) that is the canonical input for point-source tests.
- autolens_workspace_developer/jax_profiling/point_source/{image_plane,source_plane}.py — the two fit modes, with full three-tier numerical assertions and a documented source-plane blocker (Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation bug).
Reference scripts
- Three-step pattern template: @autolens_workspace_test/scripts/jax_likelihood_functions/imaging/mge.py.
- Fit-mode templates: @autolens_workspace_developer/jax_profiling/point_source/image_plane.py and source_plane.py — use their AnalysisPoint(..., fit_positions_cls=..., use_jax=True) wiring, PointSolver.for_grid(..., xp=jnp) and the ray-trace prefix pattern.
- Dataset source: @autolens_workspace_test/scripts/point_source/simulators/point_source.py — do not clone the simulator into jax_likelihood_functions/point_source/. Load the already-simulated dataset via the should_simulate bootstrap (same pattern as every other jax_likelihood_functions script), invoking the simulator in scripts/point_source/simulators/ via subprocess.run.
Scripts to add in autolens_workspace_test/scripts/jax_likelihood_functions/point_source/
- image_plane.py
  - AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsImagePairAll, use_jax=True).
  - PointSolver.for_grid(grid=..., pixel_scale_precision=0.001, magnification_threshold=0.1, xp=jnp).
  - Three-step pattern: vmap assertion → Path A jax.jit(analysis.fit_from) round-trip → PASS: jit(fit_from) round-trip matches NumPy scalar.
  - Image-plane fitting is known to JIT end-to-end (the profiling script confirms), so this variant should pass cleanly once FitPointDataset pytree registration is in place.
- source_plane.py
  - AnalysisPoint(dataset=dataset, solver=solver, fit_positions_cls=al.FitPositionsSource, use_jax=True).
  - Three-step pattern, but wrap Path A (jax.jit(analysis.fit_from)) in a try/except jax.errors.TracerArrayConversionError block that prints a clear BLOCKER line and sets a full_pipeline_jits=False sentinel, mirroring jax_profiling/point_source/source_plane.py.
  - Still assert the eager NumPy path against a hardcoded log-likelihood regression constant so the NumPy side of the pipeline is guarded.
  - Once the Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation fix lands, the script will JIT cleanly without modification.
Dataset reuse
Both scripts load from dataset/point_source/simple/point_dataset_positions_only.json (or
whatever filename the autolens_workspace_test simulator writes — confirm by reading the existing
simulator). The should_simulate bootstrap runs
scripts/point_source/simulators/point_source.py if the dataset is missing. Do not write a
new simulator under jax_likelihood_functions/point_source/ — reuse the existing one.
Pytree registration
Both scripts call enable_pytrees() + register_model(model) before the Path A JIT. FitPointDataset
registration is covered by the issued fit_point_pytree.md prompt — reference it as a dependency.
If that library-side work has not landed yet, Path A will fail with a pytree-registration error;
document and park rather than hand-register in the test scripts.
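The pytree-registration dependency comes from a JAX rule: jax.jit can only return pytrees whose leaves are arrays. The toy below uses hypothetical UnregisteredFit and RegisteredFit classes (not FitPointDataset) to show the kind of error Path A hits when a fit object is unregistered, and what library-side registration with jax.tree_util.register_pytree_node_class changes. It is a sketch of the mechanism only; per the plan, registration belongs in the library task, not in the test scripts.

```python
import jax
import jax.numpy as jnp


class UnregisteredFit:
    """Plain Python object: jax.jit cannot return it."""

    def __init__(self, log_likelihood):
        self.log_likelihood = log_likelihood


@jax.tree_util.register_pytree_node_class
class RegisteredFit:
    """Same container, registered so JAX can flatten and rebuild it."""

    def __init__(self, log_likelihood):
        self.log_likelihood = log_likelihood

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.log_likelihood,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def log_likelihood_of(positions):
    return -0.5 * jnp.sum(positions**2)


def unregistered_fit_from(positions):
    return UnregisteredFit(log_likelihood_of(positions))


def registered_fit_from(positions):
    return RegisteredFit(log_likelihood_of(positions))


positions = jnp.array([0.5, -1.0, 2.0])

try:
    jax.jit(unregistered_fit_from)(positions)
    unregistered_jits = True
except TypeError as error:
    # JAX rejects the unregistered output as "not a valid JAX type".
    print(f"Pytree-registration failure: {error}")
    unregistered_jits = False

fit = jax.jit(registered_fit_from)(positions)
```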
Deliverables
- Two new scripts in autolens_workspace_test/scripts/jax_likelihood_functions/point_source/: image_plane.py, source_plane.py.
  - image_plane.py prints PASS: jit(fit_from) round-trip matches NumPy scalar.
  - source_plane.py either prints the PASS line (if the upstream xp bug is fixed) or a clear BLOCKER line identifying the eager vs JIT divergence — same pattern as the profiling script.
- Update autolens_workspace_test/scripts/CLAUDE.md to list the two new scripts in the jax_likelihood_functions/ table.
Scope boundary
- Do not introduce a new simulator — reuse scripts/point_source/simulators/point_source.py.
- Do not work around the source-plane xp-propagation blocker in the test script. Document it and leave the fix for the library task (fit_point_pytree.md).
- Do not change autolens_workspace_developer/jax_profiling/point_source/ — that is the profiling reference, not the test surface.