
fix: AdaptImages galaxy-identity mismatch across jax.jit boundary #369

@Jammy2211

Description


Overview

jax.jit(analysis.fit_from)(instance) crashes for any model using AdaptImages because
galaxy_image_dict is keyed by Galaxy instances, whose .id (and therefore hash) changes
when JAX reconstructs them via Model.instance_unflatten. The fix is to look up adapt images
by galaxy path tuple — a keying scheme AdaptImages already supports via
galaxy_name_image_dict — and re-enable three workspace_test scripts deferred from PR #364.

Plan

  • Fix the Galaxy-identity mismatch that breaks AdaptImages lookups after JAX unflatten in PyAutoGalaxy.
  • Use the lookup pattern already supported by AdaptImages: index by galaxy path tuple (stable across unflatten) instead of by Galaxy instance (auto-incremented .id).
  • Add a non-JAX unit test that simulates the unflatten scenario and confirms the lookup still resolves the correct adapt image.
  • (Stretch) Apply the same fix to PyAutoLens to_inversion.py and remove the existing single-pixelated-galaxy fallback that masks the same bug.
  • Re-port the three deferred scripts (rectangular_mge.py, delaunay.py, delaunay_mge.py) from autolens_workspace_test into autogalaxy_workspace_test.
  • Re-enable those scripts in smoke_tests.txt, plus delaunay_mge.py commented out with the jax-0.7 regression note.
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary)
  • autogalaxy_workspace_test
  • PyAutoLens (stretch — same library PR)

Work Classification

Both — library work first, workspace re-port follows.

Branch Survey

Repository                    Current Branch   Dirty?
./PyAutoGalaxy                main             clean
./PyAutoLens                  main             clean
./autogalaxy_workspace_test   main             clean

Suggested branch: feature/adapt-images-pytree-fix
Worktree root: ~/Code/PyAutoLabs-wt/adapt-images-pytree-fix/ (created later by /start_library)

Implementation Steps

  1. autogalaxy/galaxy/to_inversion.py:554–570 — replace by-instance lookup with a helper call.
  2. autogalaxy/analysis/adapt_images/adapt_images.py — add image_for_galaxy(galaxy, instance) -> Optional[Array2D]: try galaxy_image_dict[galaxy] first, fall back to galaxy_name_image_dict[path].
  3. No change to pytree registration — path-tuple keying is the cleaner invariant.
  4. New non-JAX unit test in test_autogalaxy/analysis/test_adapt_images.py exercising fresh-Galaxy-same-path lookup.
  5. (Stretch) autolens/lens/to_inversion.py:272–290 — replace single-pixelated-galaxy fallback with the same helper.
  6. Re-port rectangular_mge.py, delaunay.py, delaunay_mge.py from autolens_workspace_test references; restore adapt variant of rectangular.py.
  7. Re-enable the four scripts in autogalaxy_workspace_test/smoke_tests.txt; add delaunay_mge.py commented with the jax-0.7 regression note.
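A minimal sketch of the helper described in step 2, under stated assumptions. `AdaptImagesSketch`, `Galaxy` and `path_of` are hypothetical stand-ins, not the PyAutoGalaxy API. `galaxy_name_image_dict` is assumed to be keyed by path tuples such as ('galaxies', 'galaxy'). The plan does not say how the real helper derives the path from the instance, so the sketch uses a naive attribute walk. It also returns `Any` rather than Optional[Array2D].

```python
from itertools import count
from types import SimpleNamespace
from typing import Any, Dict, Optional, Tuple

_next_id = count()


class Galaxy:
    """Stand-in for ag.Galaxy whose hash is its auto-incremented id."""

    def __init__(self, redshift: float):
        self.redshift = redshift
        self.id = next(_next_id)

    def __hash__(self) -> int:
        return int(self.id)


def path_of(target: Any, obj: Any, prefix: Tuple[str, ...] = ()) -> Optional[Tuple[str, ...]]:
    """Hypothetical helper: depth-first search for the attribute path of `target` in `obj`."""
    if obj is target:
        return prefix
    if isinstance(obj, Galaxy) or not hasattr(obj, "__dict__"):
        return None
    for name, value in vars(obj).items():
        found = path_of(target, value, prefix + (name,))
        if found is not None:
            return found
    return None


class AdaptImagesSketch:
    """Hypothetical stand-in for AdaptImages carrying the proposed image_for_galaxy helper."""

    def __init__(self, galaxy_image_dict=None, galaxy_name_image_dict=None):
        self.galaxy_image_dict: Dict[Galaxy, Any] = galaxy_image_dict or {}
        # Assumption: keys are galaxy path tuples, e.g. ("galaxies", "galaxy").
        self.galaxy_name_image_dict: Dict[Tuple[str, ...], Any] = galaxy_name_image_dict or {}

    def image_for_galaxy(self, galaxy: Galaxy, instance: Any) -> Optional[Any]:
        # 1. Identity lookup: succeeds when `galaxy` is the trace-time instance.
        if galaxy in self.galaxy_image_dict:
            return self.galaxy_image_dict[galaxy]
        # 2. Path lookup: the path does not depend on .id, so it survives unflatten.
        path = path_of(galaxy, instance)
        if path is None:
            return None
        return self.galaxy_name_image_dict.get(path)


traced = Galaxy(redshift=0.5)
adapt_images = AdaptImagesSketch(
    galaxy_image_dict={traced: "adapt image"},
    galaxy_name_image_dict={("galaxies", "galaxy"): "adapt image"},
)

# Post-unflatten: a fresh Galaxy (new id) at the same path in a rebuilt instance.
fresh = Galaxy(redshift=0.5)
instance = SimpleNamespace(galaxies=SimpleNamespace(galaxy=fresh))

resolved = adapt_images.image_for_galaxy(fresh, instance)
print(resolved)  # adapt image
```

The identity lookup is tried first, so callers that still pass trace-time galaxies behave exactly as before. The path fallback only applies when the id no longer matches, which is the post-unflatten case the unit test in step 4 is meant to cover.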

Key Files

  • PyAutoGalaxy/autogalaxy/galaxy/to_inversion.py — lookup site
  • PyAutoGalaxy/autogalaxy/analysis/adapt_images/adapt_images.py — new helper
  • PyAutoGalaxy/test_autogalaxy/analysis/test_adapt_images.py — new unit test
  • PyAutoLens/autolens/lens/to_inversion.py — (stretch) drop fallback
  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/{rectangular,rectangular_mge,delaunay,delaunay_mge}.py
  • autogalaxy_workspace_test/smoke_tests.txt

Original Prompt


Fix the AdaptImages.galaxy_image_dict Galaxy-identity mismatch across the jax.jit boundary in
@PyAutoGalaxy, and re-enable the three autogalaxy_workspace_test scripts that it blocks.

Problem

When jax.jit(analysis.fit_from)(instance) returns a FitImaging via the pytree registration
added in PyAutoGalaxy PR #364, accessing fit.log_likelihood post-unflatten fails for any model
that uses AdaptImages:

AttributeError: 'NoneType' object has no attribute 'array'
  File ".../rectangular_adapt_image.py", line 93, in mesh_weight_map_from
    mesh_weight_map = adapt_data.array

Root cause

FitImaging is registered with no_flatten=("dataset", "adapt_images", "settings"), so
adapt_images rides across the pytree boundary as aux — its galaxy_image_dict keys are the
trace-time ag.Galaxy instances. self.galaxies is registered as a pytree (dynamic), so
post-unflatten it contains fresh Galaxy instances built via autofit.Model.instance_unflatten →
self.cls(*constructor_arguments), each with a new .id. hash(galaxy) returns int(self.id),
so the fresh Galaxy doesn't match any key in adapt_images.galaxy_image_dict. The lookup at
PyAutoGalaxy/autogalaxy/galaxy/to_inversion.py:555 raises KeyError → adapt_galaxy_image = None →
mesh.mesh_weight_map_from(adapt_data=None) blows up.
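The failure chain can be reproduced without JAX or PyAutoGalaxy. The sketch below is a minimal pure-Python stand-in that assumes only what is stated above: the hash is derived from an auto-incremented .id, and unflatten rebuilds the object through its constructor. `Galaxy` here is not ag.Galaxy, and `unflatten_like` is a hypothetical replacement for Model.instance_unflatten.

```python
from itertools import count
from types import SimpleNamespace

_next_id = count()


class Galaxy:
    """Stand-in for ag.Galaxy: every construction takes a fresh auto-incremented id."""

    def __init__(self, redshift):
        self.redshift = redshift
        self.id = next(_next_id)

    def __hash__(self):
        # As described above: hash(galaxy) returns int(self.id).
        return int(self.id)


def unflatten_like(galaxy):
    """Hypothetical stand-in for Model.instance_unflatten -> self.cls(*constructor_arguments)."""
    return type(galaxy)(galaxy.redshift)


# Trace time: adapt images are keyed by the trace-time Galaxy instance.
traced = Galaxy(redshift=0.5)
galaxy_image_dict = {traced: SimpleNamespace(array=[1.0, 2.0, 3.0])}

# After unflatten: same constructor arguments, new id, so a different hash.
fresh = unflatten_like(traced)

# The lookup raises KeyError, which leaves adapt_galaxy_image = None ...
try:
    adapt_galaxy_image = galaxy_image_dict[fresh]
except KeyError:
    adapt_galaxy_image = None

# ... and the mesh then reads adapt_data.array from None.
try:
    adapt_galaxy_image.array
    error = None
except AttributeError as exc:
    error = str(exc)

print(error)  # 'NoneType' object has no attribute 'array'
```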

The analogous fix on the autolens side that solved a similar dict-keyed-by-instance problem is
tracked at @admin_jammy/prompt/autolens/linear_light_profile_intensity_dict_pytree.md.

Note that autolens's jax_likelihood_functions/imaging/rectangular.py currently passes in
autolens_workspace_test despite apparently having the same Galaxy-identity issue — worth checking
what autolens does differently (e.g. a shared fix in autoarray / autofit, or a subtly different
FitImaging inversion path that bypasses the galaxy_image_dict lookup for the adapt mesh). That
diff may reveal the minimal fix for autogalaxy, or a broader pattern that should be lifted into
autoarray.

Scripts blocked

From @autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/, these were deferred in
the initial task 3/9 ship (PyAutoGalaxy PR #364, workspace PR on autogalaxy_workspace_test):

  • rectangular_mge.py: MGE bulge + ag.mesh.RectangularAdaptImage + ag.reg.Adapt
  • delaunay.py: ag.mesh.Delaunay + ag.image_mesh.Hilbert (or Overlay, which still wires
    the image-plane mesh grid via adapt_images.galaxy_name_image_plane_mesh_grid_dict)
  • delaunay_mge.py: MGE bulge + Delaunay

The initial ship used ag.mesh.RectangularUniform + ag.reg.Constant for rectangular.py (no
adapt dependency), which does pass. After this fix lands, re-port the three scripts above using
the proper adapt-image autolens references at
@autolens_workspace_test/scripts/jax_likelihood_functions/imaging/{rectangular,rectangular_mge,delaunay,delaunay_mge}.py.

Deliverables

  1. PyAutoGalaxy library fix for the Galaxy-identity issue across the JIT boundary.
    Candidate approaches (pick whichever is cleanest):
    • Key AdaptImages.galaxy_image_dict by the galaxy's path-tuple (e.g. ('galaxies', 'galaxy'))
      instead of the Galaxy instance — stable across unflatten.
    • Look up by galaxy.id via an identity map that is rebuilt during fit_from (not carried as
      aux).
    • Register Galaxy with a custom pytree that preserves .id through unflatten so hashes match.
    • Move the adapt-image lookup inside fit_from so it runs during tracing (before the pytree
      boundary) and stores adapt_galaxy_image on the mapper directly rather than looking it up
      lazily.
  2. Unit test in test_autogalaxy/ exercising the fix without importing JAX (follow the
    numpy-only unit test convention — cross-xp checks live in workspace_test).
  3. Re-port the three deferred scripts into autogalaxy_workspace_test using the autolens
    references, and re-enable them in smoke_tests.txt alongside the existing
    jax_likelihood_functions/imaging/ entries.
  4. Add jax_likelihood_functions/imaging/delaunay_mge.py commented out in smoke_tests.txt with
    the exact jax-0.7 regression comment from autolens's smoke_tests.txt.
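For comparison, candidate approach 3 (keeping .id stable through unflatten) amounts to a flatten/unflatten pair that carries the id as static aux data. This is a hedged sketch with a stand-in `Galaxy`. The functions are only registered via jax.tree_util.register_pytree_node when JAX is importable; the demonstration calls them directly. One assumption matters: a dict lookup needs equality as well as a matching hash, so the stand-in compares galaxies by id. Whether ag.Galaxy does the same would need checking before relying on this route.

```python
from itertools import count

_next_id = count()


class Galaxy:
    """Stand-in for ag.Galaxy: hash is the auto-incremented id."""

    def __init__(self, redshift):
        self.redshift = redshift
        self.id = next(_next_id)

    def __hash__(self):
        return int(self.id)

    def __eq__(self, other):
        # Assumption: equality follows id too, otherwise a matching hash is not enough.
        return isinstance(other, Galaxy) and self.id == other.id


def galaxy_flatten(galaxy):
    # Dynamic leaves go in `children`; the id travels as static aux data.
    children = (galaxy.redshift,)
    aux_data = galaxy.id
    return children, aux_data


def galaxy_unflatten(aux_data, children):
    # Rebuilding consumes a new id, so restore the original one afterwards.
    galaxy = Galaxy(*children)
    galaxy.id = aux_data
    return galaxy


try:  # Registration is optional here; the functions above do not depend on JAX.
    import jax

    jax.tree_util.register_pytree_node(Galaxy, galaxy_flatten, galaxy_unflatten)
except ImportError:
    pass


traced = Galaxy(redshift=0.5)
galaxy_image_dict = {traced: "adapt image"}

children, aux = galaxy_flatten(traced)
rebuilt = galaxy_unflatten(aux, children)
print(galaxy_image_dict[rebuilt])  # adapt image
```

A likely cost of this route, and one reason path-tuple keying may be the cleaner invariant (Implementation Step 3), is that the id becomes part of the pytree structure. Each newly constructed galaxy could then present jax.jit with a different treedef and trigger a retrace.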

Dependencies

Umbrella

Follow-up from PyAutoLabs/autogalaxy_workspace_test#8 (epic #5 task 3/9). Same issue may re-surface
in task 4/9 (jax_likelihood_interferometer) and task 5/9 (jax_likelihood_multi) — either fix
once here, or carry the same deferral pattern into those tasks.
