Overview
jax.jit(analysis.fit_from)(instance) crashes for any model using AdaptImages because
galaxy_image_dict is keyed by Galaxy instances, whose .id (and therefore hash) changes
when JAX reconstructs them via Model.instance_unflatten. The fix is to look up adapt images
by galaxy path tuple — a keying scheme AdaptImages already supports via
galaxy_name_image_dict — and re-enable three workspace_test scripts deferred from PR #364.
Plan
- Fix the Galaxy-identity mismatch that breaks AdaptImages lookups after JAX unflatten in PyAutoGalaxy.
- Use the lookup pattern already supported by AdaptImages: index by galaxy path tuple (stable across unflatten) instead of by Galaxy instance (auto-incremented .id).
- Add a non-JAX unit test that simulates the unflatten scenario and confirms the lookup still resolves the correct adapt image.
- (Stretch) Apply the same fix to PyAutoLens to_inversion.py and remove the existing single-pixelated-galaxy fallback that masks the same bug.
- Re-port the three deferred scripts (rectangular_mge.py, delaunay.py, delaunay_mge.py) from autolens_workspace_test into autogalaxy_workspace_test.
- Re-enable those scripts in smoke_tests.txt, plus delaunay_mge.py commented out with the jax-0.7 regression note.
Detailed implementation plan
Affected Repositories
- PyAutoGalaxy (primary)
- autogalaxy_workspace_test
- PyAutoLens (stretch — same library PR)
Work Classification
Both — library work first, workspace re-port follows.
Branch Survey
| Repository | Current Branch | Dirty? |
| ./PyAutoGalaxy | main | clean |
| ./PyAutoLens | main | clean |
| ./autogalaxy_workspace_test | main | clean |
Suggested branch: feature/adapt-images-pytree-fix
Worktree root: ~/Code/PyAutoLabs-wt/adapt-images-pytree-fix/ (created later by /start_library)
Implementation Steps
- autogalaxy/galaxy/to_inversion.py:554–570: replace by-instance lookup with a helper call.
- autogalaxy/analysis/adapt_images/adapt_images.py: add image_for_galaxy(galaxy, instance) -> Optional[Array2D], which tries galaxy_image_dict[galaxy] first and falls back to galaxy_name_image_dict[path].
- No change to pytree registration; path-tuple keying is the cleaner invariant.
- New non-JAX unit test in test_autogalaxy/analysis/test_adapt_images.py exercising fresh-Galaxy-same-path lookup.
- (Stretch) autolens/lens/to_inversion.py:272–290: replace single-pixelated-galaxy fallback with the same helper.
- Re-port rectangular_mge.py, delaunay.py, delaunay_mge.py from autolens_workspace_test references; restore adapt variant of rectangular.py.
- Re-enable the four scripts in autogalaxy_workspace_test/smoke_tests.txt; add delaunay_mge.py commented with the jax-0.7 regression note.
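The helper described in the steps above could be sketched roughly as follows. This is a minimal illustration, not the PyAutoGalaxy implementation: AdaptImagesSketch is a hypothetical stand-in class, the instance argument is replaced by a plain list of (path, galaxy) pairs, and keying galaxy_name_image_dict by a tuple of strings is an assumption based on the path-tuple wording in this plan.

```python
from typing import Any, Dict, Iterable, Optional, Tuple

Path = Tuple[str, ...]


class AdaptImagesSketch:
    """Hypothetical stand-in for AdaptImages, holding both keying schemes."""

    def __init__(
        self,
        galaxy_image_dict: Optional[Dict[Any, Any]] = None,
        galaxy_name_image_dict: Optional[Dict[Path, Any]] = None,
    ):
        self.galaxy_image_dict = galaxy_image_dict or {}
        self.galaxy_name_image_dict = galaxy_name_image_dict or {}

    def image_for_galaxy(
        self, galaxy: Any, path_galaxy_pairs: Iterable[Tuple[Path, Any]]
    ) -> Optional[Any]:
        # 1) Fast path: the galaxy is one of the original dictionary keys.
        try:
            return self.galaxy_image_dict[galaxy]
        except KeyError:
            pass

        # 2) Fallback: find the path of this galaxy in the current instance
        #    and look the image up by path, which survives reconstruction.
        for path, candidate in path_galaxy_pairs:
            if candidate is galaxy:
                return self.galaxy_name_image_dict.get(path)

        return None


# Usage: `traced` plays the trace-time galaxy, `fresh` the rebuilt one.
traced, fresh = object(), object()
path = ("galaxies", "galaxy")
adapt_images = AdaptImagesSketch(
    galaxy_image_dict={traced: "adapt image"},
    galaxy_name_image_dict={path: "adapt image"},
)
image = adapt_images.image_for_galaxy(fresh, [(path, fresh)])
```

Trying the instance lookup first keeps existing non-JAX callers unchanged, while the path fallback only runs when the instance key is missing.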
Key Files
- PyAutoGalaxy/autogalaxy/galaxy/to_inversion.py: lookup site
- PyAutoGalaxy/autogalaxy/analysis/adapt_images/adapt_images.py: new helper
- PyAutoGalaxy/test_autogalaxy/analysis/test_adapt_images.py: new unit test
- PyAutoLens/autolens/lens/to_inversion.py: (stretch) drop fallback
- autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/{rectangular,rectangular_mge,delaunay,delaunay_mge}.py
- autogalaxy_workspace_test/smoke_tests.txt
Original Prompt
Fix AdaptImages.galaxy_image_dict Galaxy-identity mismatch across jax.jit boundary in
@PyAutoGalaxy, and re-enable the three autogalaxy_workspace_test scripts that this blocks.
Problem
When jax.jit(analysis.fit_from)(instance) returns a FitImaging via the pytree registration
added in PyAutoGalaxy PR #364, accessing fit.log_likelihood post-unflatten fails for any model
that uses AdaptImages:
AttributeError: 'NoneType' object has no attribute 'array'
File ".../rectangular_adapt_image.py", line 93, in mesh_weight_map_from
mesh_weight_map = adapt_data.array
Root cause
FitImaging is registered with no_flatten=("dataset", "adapt_images", "settings"), so
adapt_images rides across the pytree boundary as aux — its galaxy_image_dict keys are the
trace-time ag.Galaxy instances. self.galaxies is registered as a pytree (dynamic), so
post-unflatten it contains fresh Galaxy instances built via autofit.Model.instance_unflatten
→ self.cls(*constructor_arguments), each with a new .id. hash(galaxy) returns int(self.id),
so the fresh Galaxy doesn't match any key in adapt_images.galaxy_image_dict. The lookup at
PyAutoGalaxy/autogalaxy/galaxy/to_inversion.py:555 raises KeyError → adapt_galaxy_image = None
→ mesh.mesh_weight_map_from(adapt_data=None) blows up.
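The mismatch described above can be reproduced without JAX. The sketch below uses a hypothetical Galaxy stand-in whose hash is its auto-incremented id, as described for ag.Galaxy; the __eq__ definition and the rebuild helper are assumptions made only for illustration.

```python
import itertools

_ids = itertools.count()


class Galaxy:
    """Hypothetical stand-in for ag.Galaxy: hash comes from an auto-incremented id."""

    def __init__(self, redshift):
        self.redshift = redshift
        self.id = next(_ids)

    def __hash__(self):
        return int(self.id)

    def __eq__(self, other):
        # Assumption for this sketch: equality follows the id as well.
        return isinstance(other, Galaxy) and self.id == other.id


def rebuild(galaxy):
    """Mimics unflatten: call the constructor again with the same arguments."""
    return Galaxy(galaxy.redshift)


traced = Galaxy(redshift=0.5)
galaxy_image_dict = {traced: "adapt image"}

fresh = rebuild(traced)  # same constructor arguments, new id

try:
    adapt_galaxy_image = galaxy_image_dict[fresh]
except KeyError:
    adapt_galaxy_image = None

print(adapt_galaxy_image)  # None
```

The fresh instance is equivalent in every modelling sense, but the dictionary only sees a different hash, so the lookup misses.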
The analogous fix on the autolens side that solved a similar dict-keyed-by-instance problem is
tracked at @admin_jammy/prompt/autolens/linear_light_profile_intensity_dict_pytree.md.
Note that autolens's jax_likelihood_functions/imaging/rectangular.py currently passes in
autolens_workspace_test despite apparently having the same Galaxy-identity issue — worth checking
what autolens does differently (e.g. a shared fix in autoarray / autofit, or a subtly different
FitImaging inversion path that bypasses the galaxy_image_dict lookup for the adapt mesh). That
diff may reveal the minimal fix for autogalaxy, or a broader pattern that should be lifted into
autoarray.
Scripts blocked
From @autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/, these were deferred in
the initial task 3/9 ship (PyAutoGalaxy PR #364, workspace PR on autogalaxy_workspace_test):
- rectangular_mge.py: MGE bulge + ag.mesh.RectangularAdaptImage + ag.reg.Adapt
- delaunay.py: ag.mesh.Delaunay + ag.image_mesh.Hilbert (or Overlay, which still wires the image-plane mesh grid via adapt_images.galaxy_name_image_plane_mesh_grid_dict)
- delaunay_mge.py: MGE bulge + Delaunay
The initial ship used ag.mesh.RectangularUniform + ag.reg.Constant for rectangular.py (no
adapt dependency), which does pass. After this fix lands, re-port the three scripts above using
the proper adapt-image autolens references at
@autolens_workspace_test/scripts/jax_likelihood_functions/imaging/{rectangular,rectangular_mge,delaunay,delaunay_mge}.py.
Deliverables
- PyAutoGalaxy library fix for the Galaxy-identity issue across the JIT boundary. Candidate approaches (pick whichever is cleanest):
  - Key AdaptImages.galaxy_image_dict by the galaxy's path-tuple (e.g. ('galaxies', 'galaxy')) instead of the Galaxy instance; the path tuple is stable across unflatten.
  - Look up by galaxy.id via an identity map that is rebuilt during fit_from (not carried as aux).
  - Register Galaxy with a custom pytree that preserves .id through unflatten so hashes match.
  - Move the adapt-image lookup inside fit_from so it runs during tracing (before the pytree boundary) and stores adapt_galaxy_image on the mapper directly rather than looking it up lazily.
- Unit test in test_autogalaxy/ exercising the fix without importing JAX (follow the numpy-only unit test convention; cross-xp checks live in workspace_test).
- Re-port the three deferred scripts into autogalaxy_workspace_test using the autolens references, and re-enable them in smoke_tests.txt alongside the existing jax_likelihood_functions/imaging/ entries.
- Add jax_likelihood_functions/imaging/delaunay_mge.py commented out in smoke_tests.txt with the exact jax-0.7 regression comment from autolens's smoke_tests.txt.
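For comparison, the second candidate approach in the list above (an identity map rebuilt during fit_from) might look roughly like this. It is a sketch under stated assumptions: Galaxy is a hypothetical stand-in with an auto-incremented id, id_image_map_from is an invented helper name, and the (path, galaxy) pairs stand in for whatever the model instance exposes.

```python
import itertools
from typing import Any, Dict, Iterable, Tuple

Path = Tuple[str, ...]

_ids = itertools.count()


class Galaxy:
    """Hypothetical stand-in for ag.Galaxy with an auto-incremented id."""

    def __init__(self):
        self.id = next(_ids)


def id_image_map_from(
    galaxy_name_image_dict: Dict[Path, Any],
    path_galaxy_pairs: Iterable[Tuple[Path, Galaxy]],
) -> Dict[int, Any]:
    """Rebuild an id -> image map from the path-keyed images and the current galaxies."""
    return {
        galaxy.id: galaxy_name_image_dict[path]
        for path, galaxy in path_galaxy_pairs
        if path in galaxy_name_image_dict
    }


name_images = {("galaxies", "galaxy"): "adapt image"}

# Each call to fit_from would see freshly built galaxies, so the map is rebuilt each time.
first = Galaxy()
first_map = id_image_map_from(name_images, [(("galaxies", "galaxy"), first)])

second = Galaxy()
second_map = id_image_map_from(name_images, [(("galaxies", "galaxy"), second)])
```

Because the map is derived from the path-keyed dictionary on every call, nothing keyed by a stale id has to cross the pytree boundary.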
Dependencies
- … update in the same library PR.
Umbrella
Follow-up from PyAutoLabs/autogalaxy_workspace_test#8 (epic #5 task 3/9). Same issue may re-surface
in task 4/9 (jax_likelihood_interferometer) and task 5/9 (jax_likelihood_multi) — either fix
once here, or carry the same deferral pattern into those tasks.