
feat: AnalysisImaging pytree registration + jax_likelihood_functions/imaging/ scripts #8

@Jammy2211


Overview

Adds scripts/jax_likelihood_functions/imaging/ to autogalaxy_workspace_test. These are autogalaxy
ports of 8 of the 10 autolens_workspace_test JAX-likelihood imaging scripts; the two lens-specific
*_dspl.py variants are excluded. This unblocks the jax.jit(analysis.fit_from) scalar round-trip on
autogalaxy's imaging path.

Prerequisite library work: autogalaxy.imaging.model.analysis.AnalysisImaging currently has no
_register_fit_imaging_pytrees method, so jax.jit(fit_from) cannot flatten the FitImaging object it
returns. A PyAutoGalaxy library PR that mirrors autolens's implementation therefore ships first.
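
The failure mode can be reproduced outside autogalaxy. In the minimal sketch below, ToyFit and fit_from are hypothetical stand-ins (not autogalaxy's FitImaging or AnalysisImaging.fit_from), and only JAX is assumed to be installed. jax.jit rejects a return value whose class is not a registered pytree, and registering the class with jax.tree_util.register_pytree_node makes the same call succeed.

```python
import jax
import jax.numpy as jnp


class ToyFit:
    """Hypothetical stand-in for the fit object returned by fit_from."""

    def __init__(self, model_data):
        self.model_data = model_data


def fit_from(amplitude):
    # Build a tiny "model image" and wrap it in the fit object.
    return ToyFit(model_data=amplitude * jnp.ones(3))


# Without registration, jit cannot flatten the ToyFit return value.
try:
    jax.jit(fit_from)(2.0)
    unregistered_failed = False
except TypeError:
    unregistered_failed = True

# Register ToyFit as a pytree node whose only child is model_data.
jax.tree_util.register_pytree_node(
    ToyFit,
    lambda fit: ((fit.model_data,), None),  # flatten -> (children, aux_data)
    lambda aux, children: ToyFit(*children),  # unflatten(aux_data, children)
)

fit = jax.jit(fit_from)(2.0)
print(unregistered_failed, type(fit).__name__, fit.model_data)
```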

Part of epic #5 (task 3/9).

Plan

  • Library PR on PyAutoGalaxy: add _register_fit_imaging_pytrees to AnalysisImaging, modelled
    verbatim on autolens.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytrees.
    Register autogalaxy's FitImaging (with no_flatten=("dataset", "adapt_images", "settings")),
    the shared DatasetModel (idempotent, since autolens already registers it), and the galaxies
    container that fit_from passes through (ag.Galaxies if used, otherwise the raw list; verify
    during implementation). Call the method from fit_from under the existing use_jax gate.
  • Merge the library PR, pull, then port 8 scripts to
    autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/.
  • Each script follows the three-step pattern (NumPy baseline → jax.jit(fit_from) → scalar
    log_likelihood match) and prints "PASS: jit(fit_from) round-trip matches NumPy scalar."
  • Append the scripts to smoke_tests.txt, with delaunay_mge.py commented out using the same
    jax 0.7 regression comment as the autolens entry.
  • Open the workspace PR; the library-first merge gate is enforced by /ship_workspace.
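
The three-step check can be sketched on a toy Gaussian log-likelihood. Nothing here is autogalaxy API. log_likelihood_numpy and log_likelihood_jax are hypothetical stand-ins for fit_from(instance).log_likelihood on each backend, and the tolerance is an assumption; the real scripts should use whatever the autolens reference uses. The relative tolerance allows for JAX defaulting to float32 while NumPy computes in float64.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical toy data: a 1D "image", its noise map and a model template.
data = np.array([1.0, 2.0, 3.0, 4.0])
noise_map = np.array([0.5, 0.5, 1.0, 1.0])
template = np.array([1.1, 1.9, 3.2, 3.8])


def log_likelihood_numpy(amplitude):
    # Step 1: NumPy baseline, -0.5 * (chi^2 + sum log(2 pi sigma^2)).
    residual = (data - amplitude * template) / noise_map
    return -0.5 * (np.sum(residual**2) + np.sum(np.log(2.0 * np.pi * noise_map**2)))


def log_likelihood_jax(amplitude):
    # The same expression in jax.numpy, so that jax.jit can trace it.
    d, n, t = jnp.asarray(data), jnp.asarray(noise_map), jnp.asarray(template)
    residual = (d - amplitude * t) / n
    return -0.5 * (jnp.sum(residual**2) + jnp.sum(jnp.log(2.0 * jnp.pi * n**2)))


amplitude = 1.0
numpy_value = float(log_likelihood_numpy(amplitude))
# Step 2: the same likelihood evaluated through jax.jit.
jax_value = float(jax.jit(log_likelihood_jax)(amplitude))

# Step 3: compare the scalars before printing the PASS line.
assert np.isclose(numpy_value, jax_value, rtol=1e-5, atol=1e-6), (numpy_value, jax_value)
print("PASS: jit(fit_from) round-trip matches NumPy scalar.")
```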
Detailed implementation plan

Affected Repositories

  • autogalaxy_workspace_test (primary — this issue)
  • PyAutoGalaxy (library PR, separate issue + PR)

Work Classification

Both (library + workspace). Library ships first.

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean |
| ./autogalaxy_workspace_test | main | clean |

Suggested branch: feature/autogalaxy-wst-jax-lh-imaging
Task name: autogalaxy-wst-jax-lh-imaging
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-lh-imaging/ (created later by /start_library)

Library Implementation Steps (PyAutoGalaxy — ships first)

  1. autogalaxy/imaging/model/analysis.py:
    • Add a _register_fit_imaging_pytrees @staticmethod modelled exactly on the autolens
      equivalent at autolens/imaging/model/analysis.py:131. Lazy-import register_instance_pytree,
      DatasetModel, and autogalaxy's FitImaging + Galaxies inside the method (same lazy-import
      pattern autolens uses to avoid circular imports).
    • Register:
      • FitImaging with no_flatten=("dataset", "adapt_images", "settings") — mirror of autolens.
      • DatasetModel (no no_flatten) — idempotent; register_instance_pytree checks a
        registry set and skips if already registered.
      • The galaxies container returned by galaxies_via_instance_from: verify whether this is
        a Galaxies instance (which subclasses List) or a plain list. If it is a Galaxies instance,
        register it (no no_flatten is needed, since it has no cosmology attribute). If it is a
        plain list, JAX already flattens it natively.
    • Call the method from fit_from inside an if self._use_jax: block, as autolens does at lines 111–112.
  2. test_autogalaxy/imaging/model/: add a minimal test that constructs AnalysisImaging(use_jax=True),
    calls fit_from(instance) once, and asserts the registered classes appear in
    autoarray.abstract_ndarray._pytree_registered_classes. Mirrors the autolens test pattern.
  3. Open the library PR on PyAutoGalaxy. Once CI is green, merge, pull, and move on to the workspace.

Workspace Implementation Steps (autogalaxy_workspace_test — ships second)

For each of the 8 autolens scripts (simulator.py, lp.py, mge.py, mge_group.py,
rectangular.py, rectangular_mge.py, delaunay.py, delaunay_mge.py):

  1. Read the autolens reference and strip lens-specific constructs:
    • al.Tracer / al.Galaxy(redshift=0.5, …) / al.Galaxy(redshift=1.0, …) → a single plane
      of ag.Galaxy (no source/lens split).
    • al.AnalysisImaging → ag.AnalysisImaging.
    • Drop uses of deflections, mass_profiles and convergence where they have no galaxy-side
      counterpart (the lp/MGE source and pixelization paths should all have direct analogues).
  2. Keep the three-step JAX pattern:
    • NumPy: fit_numpy = analysis.fit_from(instance) → fit_numpy.log_likelihood
    • JAX JIT: fit_jax = jax.jit(analysis.fit_from)(instance) → fit_jax.log_likelihood
    • Assert that the two scalars agree to within 1e-6 (or whatever tolerance the autolens reference uses).
  3. Each script prints "PASS: jit(fit_from) round-trip matches NumPy scalar."
  4. Add scripts/jax_likelihood_functions/__init__.py and scripts/jax_likelihood_functions/imaging/__init__.py.
  5. Append to smoke_tests.txt with jax_likelihood_functions/imaging/delaunay_mge.py commented
    out using the exact autolens comment (jax 0.7 regression, references
    admin_jammy/prompt/build/smoke_workspace_fixes.md).
  6. Open the workspace PR with an ## Upstream PR section linking to the library PR.
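
Step 5 could be scripted as shown below, under two assumptions. First, smoke_tests.txt is assumed to hold one script path per line relative to scripts/, with "#" marking a commented-out entry. Second, DISABLED_REASON is only a placeholder; the real text must be copied verbatim from the autolens smoke_tests.txt entry. Skipping lines that are already present makes re-runs harmless.

```python
from pathlib import Path

# The eight ported scripts, in the order listed above.
SCRIPTS = ["simulator", "lp", "mge", "mge_group",
           "rectangular", "rectangular_mge", "delaunay", "delaunay_mge"]
DISABLED = {"delaunay_mge"}
# Placeholder only: replace with the exact autolens jax 0.7 regression comment.
DISABLED_REASON = "TODO: paste the autolens jax 0.7 regression comment here"


def smoke_test_lines(prefix="jax_likelihood_functions/imaging"):
    """Return the new smoke_tests.txt lines, commenting out disabled scripts."""
    lines = []
    for name in SCRIPTS:
        entry = f"{prefix}/{name}.py"
        if name in DISABLED:
            lines.append(f"# {DISABLED_REASON}")
            lines.append(f"# {entry}")
        else:
            lines.append(entry)
    return lines


def append_smoke_tests(path):
    """Append the entries to `path`, skipping any line that is already present."""
    path = Path(path)
    text = path.read_text() if path.exists() else ""
    existing = text.splitlines()
    new_lines = [line for line in smoke_test_lines() if line not in existing]
    if new_lines:
        # Start on a fresh line if the file does not already end with a newline.
        lead = "" if text == "" or text.endswith("\n") else "\n"
        with path.open("a") as f:
            f.write(lead + "\n".join(new_lines) + "\n")
    return new_lines
```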

Known Spawn-offs

If either of these surfaces during implementation, stop and /start_dev a separate library task:

  • Linear light profile pytree: linear_light_profile_intensity_dict_pytree identity issue
    (counterpart of the autolens-side fix at admin_jammy/prompt/autolens/linear_light_profile_intensity_dict_pytree.md).
    Only blocks fit_for_visualization, not this task's fit_from scalar round-trip.
  • Per-profile pytree registration if any autogalaxy profile isn't registered yet (follow
    pattern in admin_jammy/prompt/issued/fit_imaging_pytree_*.md).

Excluded from this task

  • rectangular_dspl.py, simulator_dspl.py — double-source-plane, lens-specific.
  • AnalysisInterferometer pytree registration — covered by task 4 (autogalaxy_workspace_test_jax_likelihood_interferometer.md), which will open its own PyAutoGalaxy library PR.
  • fit_for_visualization JIT path — out of scope for the scalar fit_from round-trip.

Key Files

Library (PyAutoGalaxy):

  • autogalaxy/imaging/model/analysis.py — add _register_fit_imaging_pytrees + call site.
  • test_autogalaxy/imaging/model/test_analysis.py (or nearest existing) — registration test.

Workspace (autogalaxy_workspace_test):

  • scripts/jax_likelihood_functions/__init__.py (new)
  • scripts/jax_likelihood_functions/imaging/__init__.py (new)
  • scripts/jax_likelihood_functions/imaging/{simulator,lp,mge,mge_group,rectangular,rectangular_mge,delaunay,delaunay_mge}.py (new)
  • smoke_tests.txt — append 8 scripts (delaunay_mge commented out).

Original Prompt


Create scripts/jax_likelihood_functions/imaging/ in @autogalaxy_workspace_test with autogalaxy
ports of every autolens JAX-likelihood imaging script, excluding the *_dspl.py double-source-
plane variants (lens-specific, no autogalaxy analogue).

See admin_jammy/prompt/issued/autogalaxy_workspace_test_jax_likelihood_imaging.md for full deliverables.

Part of umbrella epic #5 (task 3/9).
