Overview
Adds scripts/jax_likelihood_functions/imaging/ to autogalaxy_workspace_test: autogalaxy ports
of 8 autolens_workspace_test JAX-likelihood imaging scripts, i.e. every script except the two
lens-specific *_dspl.py variants. This unblocks the jax.jit(analysis.fit_from) scalar
round-trip on autogalaxy's imaging path.
Prerequisite library work: autogalaxy.imaging.model.analysis.AnalysisImaging has no
_register_fit_imaging_pytrees method today, so jax.jit(fit_from) cannot flatten the FitImaging
it returns (a jitted function may only return pytrees of arrays). A PyAutoGalaxy library PR,
mirroring the autolens implementation, ships first.
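A minimal sketch of why the registration is needed, written in plain JAX rather than against the autogalaxy API: returning an unregistered object from a jitted function fails with a TypeError, and registering the class as a pytree fixes it. ToyFit, fit_from and register_with_no_flatten are hypothetical names; the helper only imitates the idea behind no_flatten (named attributes become static auxiliary data instead of traced leaves) and is not the real register_instance_pytree.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class ToyFit:
    """Hypothetical stand-in for FitImaging: one traced array, one static field."""

    def __init__(self, residuals, settings):
        self.residuals = residuals  # array data: becomes a traced pytree leaf
        self.settings = settings  # metadata: kept static via no_flatten (should be hashable)

    @property
    def log_likelihood(self):
        return -0.5 * jnp.sum(self.residuals**2)


def register_with_no_flatten(cls, no_flatten=()):
    """Hypothetical helper: attributes named in no_flatten become static aux data."""

    def flatten(obj):
        attrs = vars(obj)
        keys = tuple(sorted(k for k in attrs if k not in no_flatten))
        static = tuple((k, attrs[k]) for k in sorted(no_flatten) if k in attrs)
        return [attrs[k] for k in keys], (keys, static)

    def unflatten(aux, children):
        keys, static = aux
        obj = object.__new__(cls)  # rebuild without re-running __init__
        for k, v in zip(keys, children):
            setattr(obj, k, v)
        for k, v in static:
            setattr(obj, k, v)
        return obj

    tree_util.register_pytree_node(cls, flatten, unflatten)


def fit_from(x):
    """Hypothetical fit function that returns a custom object, like analysis.fit_from."""
    return ToyFit(residuals=x - 1.0, settings="default")
```

Calling jax.jit(fit_from)(x) before register_with_no_flatten(ToyFit, no_flatten=("settings",)) raises the TypeError; afterwards the jitted call returns a rebuilt ToyFit whose settings come back unchanged.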
Part of epic #5 (task 3/9).
Plan
- Library PR on PyAutoGalaxy: add _register_fit_imaging_pytrees to AnalysisImaging, modelled verbatim on autolens.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytrees. Register autogalaxy's FitImaging (with no_flatten=("dataset", "adapt_images", "settings")), the shared DatasetModel (idempotent, since autolens already registers it), and the galaxies container that fit_from passes through (ag.Galaxies if used, otherwise the raw list; verify during implementation). Call the method from fit_from under the existing use_jax gate.
- Merge the library PR, pull, then port the 8 scripts to autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/.
- Each script follows the three-step pattern (NumPy baseline → jax.jit(fit_from) → scalar log_likelihood match) and prints "PASS: jit(fit_from) round-trip matches NumPy scalar."
- Append the scripts to smoke_tests.txt, disabling delaunay_mge.py with the jax-0.7 regression comment that matches the autolens entry.
- Open the workspace PR. The library-first merge gate is enforced by /ship_workspace.
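The smoke_tests.txt step can be sketched in Python. This assumes that smoke_tests.txt lists one script path per line, relative to scripts/, and that a leading # disables an entry; neither the file format nor the exact autolens jax-0.7 comment is spelled out here, so the comment below is a placeholder to be replaced with the autolens text. smoke_test_lines and append_smoke_tests are hypothetical helpers.

```python
from pathlib import Path

# The 8 ported scripts, in the order this plan lists them.
SCRIPTS = [
    "simulator", "lp", "mge", "mge_group",
    "rectangular", "rectangular_mge", "delaunay", "delaunay_mge",
]

# Placeholder only: copy the exact comment from the autolens smoke_tests.txt entry.
DISABLED = {"delaunay_mge": "# <exact autolens jax-0.7 regression comment>"}


def smoke_test_lines(prefix="jax_likelihood_functions/imaging"):
    """Build the lines to append; disabled scripts are commented out (assumed format)."""
    lines = []
    for name in SCRIPTS:
        entry = f"{prefix}/{name}.py"
        if name in DISABLED:
            lines.append(DISABLED[name])
            lines.append(f"# {entry}")
        else:
            lines.append(entry)
    return lines


def append_smoke_tests(path, lines):
    """Append lines to path, inserting a newline first if the file lacks a trailing one."""
    path = Path(path)
    existing = path.read_text() if path.exists() else ""
    separator = "" if not existing or existing.endswith("\n") else "\n"
    path.write_text(existing + separator + "\n".join(lines) + "\n")
```

Keeping the disabled entry in the file as a comment, rather than omitting it, makes it easy to re-enable once the jax 0.7 regression is fixed.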
Detailed implementation plan
Affected Repositories
- autogalaxy_workspace_test (primary — this issue)
- PyAutoGalaxy (library PR, separate issue + PR)
Work Classification
Both (library + workspace). Library ships first.
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean |
| ./autogalaxy_workspace_test | main | clean |
Suggested branch: feature/autogalaxy-wst-jax-lh-imaging
Task name: autogalaxy-wst-jax-lh-imaging
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-lh-imaging/ (created later by /start_library)
Library Implementation Steps (PyAutoGalaxy — ships first)
autogalaxy/imaging/model/analysis.py:
- Add a _register_fit_imaging_pytrees @staticmethod modelled exactly on the autolens equivalent at autolens/imaging/model/analysis.py:131. Lazy-import register_instance_pytree, DatasetModel, and autogalaxy's FitImaging + Galaxies inside the method (the same lazy-import pattern autolens uses to avoid circular imports).
- Register:
  - FitImaging with no_flatten=("dataset", "adapt_images", "settings"), mirroring autolens.
  - DatasetModel (no no_flatten). This is idempotent: register_instance_pytree checks a registry set and skips classes that are already registered.
- The galaxies container returned by galaxies_via_instance_from: verify whether it is a Galaxies instance (which subclasses List) or a plain list. If it is Galaxies, register it (no no_flatten needed, as it has no cosmology attribute); JAX matches pytree types exactly, so an unregistered list subclass would be treated as a single opaque leaf. If it is a plain list, JAX handles it natively.
- Call the method from fit_from inside the if self._use_jax: block, as autolens does at lines 111–112 of its analysis.py.
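Why the Galaxies check matters can be shown with plain JAX: the pytree registry matches exact types, so a plain list is flattened into its elements while a subclass of list is a single opaque leaf until it is registered. A minimal sketch, in which Galaxies, register_galaxies, leaf_count and fit_from are hypothetical stand-ins rather than the autogalaxy classes:

```python
import jax
from jax import tree_util


class Galaxies(list):
    """Hypothetical stand-in for a list-subclass galaxies container."""


def leaf_count(tree):
    """Number of leaves JAX sees when flattening tree."""
    return len(tree_util.tree_leaves(tree))


def register_galaxies():
    # Children are the list elements; no static aux data is needed.
    tree_util.register_pytree_node(
        Galaxies,
        lambda galaxies: (list(galaxies), None),
        lambda _aux, children: Galaxies(children),
    )


def fit_from(intensities):
    # Builds the container inside the traced function and returns it.
    return Galaxies([2.0 * value for value in intensities])
```

Before register_galaxies() runs, leaf_count(Galaxies([1.0, 2.0])) is 1 and jax.jit(fit_from) fails on the opaque leaf; afterwards the count is 2 and the jitted call returns a Galaxies of arrays.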
test_autogalaxy/imaging/model/:
- Add a minimal test that constructs AnalysisImaging(use_jax=True), calls fit_from(instance) once, and asserts that the registered classes appear in autoarray.abstract_ndarray._pytree_registered_classes. This mirrors the autolens test pattern.
- Open the library PR on PyAutoGalaxy. Once CI is green, merge → pull → move on to the workspace.
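A hedged sketch of the idempotent registration, the use_jax gate and the registration test, using only plain JAX. register_once and _pytree_registered_classes are hypothetical stand-ins for register_instance_pytree and its registry set (the real signatures may differ), and ToyAnalysis/ToyFit stand in for AnalysisImaging/FitImaging. The guard matters because JAX itself raises a ValueError on a duplicate registration, while fit_from runs many times.

```python
import jax
from jax import tree_util

_pytree_registered_classes = set()  # hypothetical registry set


def register_once(cls, flatten, unflatten):
    """Register cls as a pytree unless it is already in the registry set."""
    if cls in _pytree_registered_classes:
        return
    tree_util.register_pytree_node(cls, flatten, unflatten)
    _pytree_registered_classes.add(cls)


class ToyFit:
    """Hypothetical fit result with a single array attribute."""

    def __init__(self, log_likelihood):
        self.log_likelihood = log_likelihood


class ToyAnalysis:
    def __init__(self, use_jax):
        self._use_jax = use_jax

    @staticmethod
    def _register_fit_pytrees():
        register_once(
            ToyFit,
            lambda fit: ((fit.log_likelihood,), None),
            lambda _aux, children: ToyFit(children[0]),
        )

    def fit_from(self, instance):
        if self._use_jax:  # register only on the JAX path
            self._register_fit_pytrees()
        return ToyFit(-0.5 * instance**2)


def test_fit_from_registers_fit_pytrees():
    analysis = ToyAnalysis(use_jax=True)
    analysis.fit_from(1.0)
    analysis.fit_from(1.0)  # a second call must not re-register
    assert ToyFit in _pytree_registered_classes
```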
Workspace Implementation Steps (autogalaxy_workspace_test — ships second)
For each of the 8 autolens scripts (simulator.py, lp.py, mge.py, mge_group.py,
rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py):
- Read the autolens reference and strip lens-specific constructs:
  - al.Tracer / al.Galaxy(redshift=0.5, …) / al.Galaxy(redshift=1.0, …) → a single plane of ag.Galaxy (no source/lens split).
  - al.AnalysisImaging → ag.AnalysisImaging.
- Drop deflections / mass_profiles / convergence usages that have no galaxy-side counterpart (the lp/MGE source and pixelization paths should all have direct analogues).
- Keep the three-step JAX pattern:
  - NumPy: fit_numpy = analysis.fit_from(instance) → fit_numpy.log_likelihood
  - JAX JIT: fit_jax = jax.jit(analysis.fit_from)(instance) → fit_jax.log_likelihood
  - Assert that the two scalars agree to within 1e-6 (or whatever tolerance the autolens reference uses).
- Each script prints "PASS: jit(fit_from) round-trip matches NumPy scalar."
- Add scripts/jax_likelihood_functions/__init__.py and scripts/jax_likelihood_functions/imaging/__init__.py.
- Append the scripts to smoke_tests.txt, with jax_likelihood_functions/imaging/delaunay_mge.py commented out using the exact autolens comment (the jax 0.7 regression, which references admin_jammy/prompt/build/smoke_workspace_fixes.md).
- Open the workspace PR with an "## Upstream PR" section linking to the library PR.
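A self-contained toy version of the three-step check, with hypothetical ToyAnalysis and ToyFit classes standing in for ag.AnalysisImaging and FitImaging. As simplifying assumptions, the NumPy baseline uses a separate use_jax=False analysis, and 64-bit JAX is enabled so that a 1e-6 tolerance is meaningful against NumPy's float64; this plan does not say whether the real scripts do either.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import tree_util

jax.config.update("jax_enable_x64", True)  # match NumPy's float64 precision


@tree_util.register_pytree_node_class
class ToyFit:
    """Hypothetical stand-in for a FitImaging result."""

    def __init__(self, chi_squared, settings):
        self.chi_squared = chi_squared  # array leaf, traced under jit
        self.settings = settings  # static aux data (hashable)

    @property
    def log_likelihood(self):
        return -0.5 * self.chi_squared

    def tree_flatten(self):
        return (self.chi_squared,), self.settings

    @classmethod
    def tree_unflatten(cls, settings, children):
        return cls(children[0], settings)


class ToyAnalysis:
    """Hypothetical analysis; use_jax selects the array backend, like the real flag."""

    def __init__(self, data, noise_map, template, use_jax):
        self.data, self.noise_map, self.template = data, noise_map, template
        self._use_jax = use_jax

    def fit_from(self, instance):
        xp = jnp if self._use_jax else np
        model = instance["intensity"] * xp.asarray(self.template)
        residuals = (xp.asarray(self.data) - model) / xp.asarray(self.noise_map)
        return ToyFit(xp.sum(residuals**2), settings="toy")


def round_trip_matches(data, noise_map, template, instance, tolerance=1e-6):
    # Step 1: NumPy baseline.
    fit_numpy = ToyAnalysis(data, noise_map, template, use_jax=False).fit_from(instance)
    # Step 2: jit the JAX analysis' fit_from; ToyFit is a registered pytree.
    analysis = ToyAnalysis(data, noise_map, template, use_jax=True)
    fit_jax = jax.jit(analysis.fit_from)(instance)
    # Step 3: compare the scalar log likelihoods.
    diff = abs(float(fit_jax.log_likelihood) - float(fit_numpy.log_likelihood))
    return diff < tolerance


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    template = np.exp(-np.linspace(-2.0, 2.0, 50) ** 2)
    data = 3.0 * template + rng.normal(0.0, 0.1, size=50)
    noise_map = np.full(50, 0.1)
    assert round_trip_matches(data, noise_map, template, {"intensity": 2.5})
    print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
```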
Known Spawn-offs
If either of these surfaces during implementation, stop and /start_dev a separate library task:
- Linear light profile pytree: the linear_light_profile_intensity_dict_pytree identity issue (counterpart of the autolens-side fix at admin_jammy/prompt/autolens/linear_light_profile_intensity_dict_pytree.md). It only blocks fit_for_visualization, not this task's fit_from scalar round-trip.
- Per-profile pytree registration, if any autogalaxy profile is not registered yet (follow the pattern in admin_jammy/prompt/issued/fit_imaging_pytree_*.md).
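The plan does not describe the identity issue itself. As a general illustration only, assuming the problem resembles this pattern: flattening and unflattening a registered pytree builds new objects, so a dict keyed by the original objects (with default, identity-based hashing) no longer finds them afterwards. ToyProfile and rebuild are hypothetical names, not autogalaxy API.

```python
from jax import tree_util


class ToyProfile:
    """Hypothetical linear light profile; default __hash__/__eq__ use object identity."""

    def __init__(self, centre):
        self.centre = centre


tree_util.register_pytree_node(
    ToyProfile,
    lambda profile: ((profile.centre,), None),
    lambda _aux, children: ToyProfile(children[0]),
)


def rebuild(tree):
    """Flatten then unflatten, as happens to the outputs of a jitted function."""
    leaves, treedef = tree_util.tree_flatten(tree)
    return tree_util.tree_unflatten(treedef, leaves)
```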
Excluded from this task
rectangular_dspl.py, simulator_dspl.py — double-source-plane, lens-specific.
AnalysisInterferometer pytree registration — covered by task 4 (autogalaxy_workspace_test_jax_likelihood_interferometer.md), which will open its own PyAutoGalaxy library PR.
fit_for_visualization JIT path — out of scope for the scalar fit_from round-trip.
Key Files
Library (PyAutoGalaxy):
autogalaxy/imaging/model/analysis.py — add _register_fit_imaging_pytrees + call site.
test_autogalaxy/imaging/model/test_analysis.py (or nearest existing) — registration test.
Workspace (autogalaxy_workspace_test):
scripts/jax_likelihood_functions/__init__.py (new)
scripts/jax_likelihood_functions/imaging/__init__.py (new)
scripts/jax_likelihood_functions/imaging/{simulator,lp,mge,mge_group,rectangular,rectangular_mge,delaunay,delaunay_mge}.py (new)
smoke_tests.txt — append 8 scripts (delaunay_mge commented out).
Original Prompt
Create scripts/jax_likelihood_functions/imaging/ in @autogalaxy_workspace_test with autogalaxy
ports of every autolens JAX-likelihood imaging script, excluding the *_dspl.py double-source-
plane variants (lens-specific, no autogalaxy analogue).
See admin_jammy/prompt/issued/autogalaxy_workspace_test_jax_likelihood_imaging.md for full deliverables.
Part of umbrella epic #5 (task 3/9).