
refactor: migrate imaging/mge.py JIT closure to pytree parameter inputs #10

@Jammy2211

Overview

The flagship JAX profiling script jax_profiling/imaging/mge.py currently JITs a closure that takes a flat 1D parameter vector and unpacks it via model.instance_from_vector inside the trace. This obscures parameter identity under jax.jit, forces vmap callers to stack (batch, N) arrays, and does not match the pytree-native direction the rest of the codebase has moved toward. This task migrates the script to pytree inputs so JIT and vmap flatten/unflatten parameters automatically and preserve their names.

Plan

  • Dig through ~12-month-old git history for prior pytree implementations in PyAutoFit, PyAutoArray and PyAutoGalaxy; use them as templates rather than designing from scratch.
  • Re-enable / extend pytree registration on af.ModelInstance (and supporting classes in autogalaxy / autoarray as needed) so it can cross a jax.jit boundary directly.
  • Replace the flat jnp_params JIT input in mge.py with a structured pytree.
  • Update mapping_matrix_from_params, blurred_mm_from_params, the full-pipeline JIT and the vmap path to consume the pytree directly — removing model.instance_from_vector from inside the trace.
  • Keep profiling outputs (per-step timings, JSON summary, bar chart) format-compatible; timings should match within noise.
  • Refresh the module docstring to document pytree inputs as the recommended pattern for new profiling scripts.
  • Follow-up (separate work): propagate the same pattern to pixelization.py and delaunay.py.
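The closure change in the plan can be sketched on a toy two-parameter likelihood rather than the real mge.py pipeline. `flat_log_likelihood`, `pytree_log_likelihood` and the parameter keys are hypothetical stand-ins. The sketch shows one thing: the pytree version reads parameters by name and returns the same value as the flat-vector version.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data standing in for the imaging dataset.
x = jnp.linspace(-1.0, 1.0, 50)
data = 2.0 * jnp.exp(-0.5 * (x / 0.3) ** 2)


@jax.jit
def flat_log_likelihood(params):
    # Old style: the caller must remember that index 0 is intensity, 1 is sigma.
    intensity, sigma = params[0], params[1]
    model = intensity * jnp.exp(-0.5 * (x / sigma) ** 2)
    return -0.5 * jnp.sum((data - model) ** 2)


@jax.jit
def pytree_log_likelihood(params):
    # New style: jax.jit flattens the nested dict itself; names are preserved.
    source = params["source"]
    model = source["intensity"] * jnp.exp(-0.5 * (x / source["sigma"]) ** 2)
    return -0.5 * jnp.sum((data - model) ** 2)


flat = jnp.array([1.5, 0.25])
tree = {"source": {"intensity": jnp.array(1.5), "sigma": jnp.array(0.25)}}

print(flat_log_likelihood(flat), pytree_log_likelihood(tree))
```

In the real script, the registered `instance` would take the place of the toy dict.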
Detailed implementation plan

Affected Repositories

  • PyAutoLabs/autolens_workspace_developer (primary — originating request)
  • rhayes777/PyAutoFit (pytree registration on PriorModel / ModelInstance)
  • Jammy2211/PyAutoArray (potentially — extend existing uniform.py / triangle pytree wrapping)
  • Jammy2211/PyAutoGalaxy (potentially — Galaxy, Tracer, light/mass profiles)

Work Classification

Library + Workspace. /start_library first; autolens_workspace_developer script changes follow once library-side registration is in place.

Historical pytree references

| Repo | Commit / PR | Content |
|---|---|---|
| PyAutoFit | autofit/mapper/prior_model/prior_model.py:281,289 | tree_flatten / tree_unflatten on PriorModel already present; the register_pytree_node call at line 211 is currently commented out. Check git blame for the reason before re-enabling. |
| PyAutoFit | PR #1037 feature/jax_pytree, PR #1103 feature/graphical_pytrees | AnalysisFactor, FactorGraphModel pytree methods, prior-id children |
| PyAutoArray | 9a8d17fe "tree flatten point and circle" (Sep 2024) | autoarray/structures/triangles/shape.py |
| PyAutoArray | e6419042 + PR #137 feature/jax_tracer (Oct 2024) | autoarray/operators/over_sampling/uniform.py (still present today) |
| PyAutoGalaxy | feature/jax_mge, feature/jax_tracer branches | precursor profile/tracer pytree work |
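The historical tree_flatten / tree_unflatten methods plug into a registration pattern like the one sketched below. `ToyInstance` is a hypothetical stand-in, not PyAutoFit's `ModelInstance`. It stores named attributes and registers itself with `jax.tree_util.register_pytree_node`. The attribute names go into the auxiliary data, which must be hashable and becomes part of the jit cache key. The values become the children that JAX traces. The sketch also covers the flatten/unflatten round trip and a jit call on the instance.

```python
import jax
import jax.numpy as jnp


class ToyInstance:
    """Hypothetical stand-in for a model instance: named attributes only."""

    def __init__(self, **attributes):
        self.__dict__.update(attributes)

    def tree_flatten(self):
        # Sort names so the leaf order is deterministic.
        names = tuple(sorted(self.__dict__))
        children = tuple(self.__dict__[name] for name in names)
        aux_data = names  # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(**dict(zip(aux_data, children)))


# Equivalent to decorating the class with @jax.tree_util.register_pytree_node_class.
jax.tree_util.register_pytree_node(
    ToyInstance, ToyInstance.tree_flatten, ToyInstance.tree_unflatten
)

instance = ToyInstance(einstein_radius=jnp.array(1.6), redshift=jnp.array(0.5))

# Round trip: flatten to leaves + treedef, then rebuild.
leaves, treedef = jax.tree_util.tree_flatten(instance)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)


@jax.jit
def scaled_radius(inst):
    # Inside the trace the attributes are tracers, but they keep their names.
    return inst.einstein_radius * 2.0


print(scaled_radius(instance))
```
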

Branch Survey

| Repository | Current branch | Dirty? |
|---|---|---|
| ./autolens_workspace_developer | main | dirty (unrelated fits / json / png results + untracked probe scripts; snapshotted to z_staging/imaging-mge-pytree-migration-dirty-snapshot/; do not touch) |
| ./PyAutoFit | main | check at worktree creation |
| ./PyAutoArray | main | check at worktree creation |
| ./PyAutoGalaxy | main | check at worktree creation |

Suggested branch: feature/imaging-mge-pytree-migration (used consistently across all affected repos)
Worktree root: ~/Code/PyAutoLabs-wt/imaging-mge-pytree-migration/ (created later by /start_library)

Implementation Steps

  1. Use git show / git diff on the historical commits listed above to understand the prior pytree approach. Read autofit/mapper/prior_model/prior_model.py:205-295 and find out why register_pytree_node was commented out.
  2. Re-enable register_pytree_node on PriorModel (and ModelInstance if needed) in PyAutoFit. If the reason for the original comment-out still applies, work around it (e.g. selective leaf registration) rather than blindly uncommenting.
  3. Extend pytree registration to Galaxy, Tracer, and the lp_linear.MultiGaussian / mp.Isothermal / mp.ExternalShear profile classes used in mge.py — only as far as needed to let a ModelInstance cross a JIT boundary and be rebuilt into a Tracer inside the trace.
  4. Add unit tests in PyAutoFit / PyAutoGalaxy covering flatten → unflatten round-trip and jax.jit / jax.vmap over a simple ModelInstance.
  5. In autolens_workspace_developer/jax_profiling/imaging/mge.py:
    • After instance = model.instance_from_vector(vector=param_vector) (setup, not JIT), pass instance (pytree-registered) directly as the JIT input.
    • Rewrite mapping_matrix_from_params (mge.py:336): accept the pytree, build Tracer via al.Tracer(galaxies=list(instance.galaxies)) — no instance_from_vector in the trace.
    • Rewrite blurred_mm_from_params (mge.py:373): same change.
    • Replace the full-pipeline fitness.call(jnp_params) profile block (mge.py:554-563) with a pytree-accepting closure defined inside the script. Do NOT modify Fitness.call itself (explicit out-of-scope per original prompt).
    • vmap section (mge.py:572-603): build batched pytree via jax.tree_util.tree_map(lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), instance), vmap the new closure, keep the correctness assertion.
    • Static memory analysis block (mge.py:609-621): point batched_call at the pytree closure.
    • Update the top-of-file docstring and inline comments to document pytree inputs as the recommended pattern.
  6. Verify: cd jax_profiling/imaging && NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python mge.py. Correctness asserts must pass; timings within noise of prior JSON results.
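The vmap change in step 5 can be sketched with a toy nested-dict pytree standing in for the registered `instance`. `log_likelihood` and the parameter names are hypothetical. The batching idiom is the one the plan names: `jax.tree_util.tree_map` with `jnp.broadcast_to`. The final check mirrors the plan's correctness assertion, which requires every batch entry to equal the unbatched result.

```python
import jax
import jax.numpy as jnp

batch_size = 4

# Hypothetical parameter pytree standing in for the registered instance.
instance = {
    "lens": {"einstein_radius": jnp.array(1.6), "centre": jnp.array([0.0, 0.1])},
    "source": {"intensity": jnp.array(2.0)},
}


def log_likelihood(params):
    # Toy scalar function of every leaf; not the real MGE likelihood.
    lens = params["lens"]
    return (
        lens["einstein_radius"] * params["source"]["intensity"]
        - jnp.sum(lens["centre"] ** 2)
    )


# Prepend a batch axis to every leaf: () -> (batch,), (2,) -> (batch, 2).
batched = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), instance
)

single = jax.jit(log_likelihood)(instance)
batched_result = jax.jit(jax.vmap(log_likelihood))(batched)

# Correctness check: every batch entry reproduces the unbatched value.
assert batched_result.shape == (batch_size,)
assert jnp.allclose(batched_result, single)
```
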

Key Files

  • PyAutoFit/autofit/mapper/prior_model/prior_model.py — re-enable pytree registration
  • PyAutoGalaxy/autogalaxy/galaxy/galaxy.py, .../profiles/** — extend registration as needed
  • PyAutoLens/autolens/lens/tracer.py — extend registration as needed
  • autolens_workspace_developer/jax_profiling/imaging/mge.py — the migration

Out of scope (explicit)

  • PyAutoFit/autofit/non_linear/fitness.py (Fitness.call)
  • PyAutoFit model.instance_from_vector (production path)
  • jax_profiling/imaging/pixelization.py — follow-up
  • jax_profiling/imaging/delaunay.py — follow-up
  • Any dirty/untracked files in autolens_workspace_developer (snapshotted to z_staging/imaging-mge-pytree-migration-dirty-snapshot/; left alone)

Original Prompt

# Imaging MGE JAX JIT Profiling — Migrate to Pytree Inputs

## Context

`autolens_workspace_developer/jax_profiling/imaging/mge.py` is the
flagship JAX JIT profiling script for the MGE source model. It
builds `Fitness.call` around a 1D parameter vector
(`jnp_params`, shape `(N,)`) and JITs that closure.

Passing parameters as a flat 1D vector has two long-term downsides:

1. **`model.instance_from_vector` is the JIT boundary**. That call
   has to unpack the flat vector into every named structured
   parameter (centres, ell_comps, intensities, sigmas, redshifts,
   …) inside the trace. For a model with many parameters this is
   pure Python shape-juggling that runs once per compile but
   obscures the parameter identity under `jax.jit`.
2. **Batching with `vmap`** requires the caller to stack parameter
   vectors into a 2D `(batch, N)` array. The structured parameter
   names are lost along the way — any diagnostic that wants to
   plot likelihood vs `einstein_radius` has to know the positional
   index of `einstein_radius` in the flat vector, which is
   fragile.
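`jax.flatten_util.ravel_pytree` illustrates the index-bookkeeping problem in point 2. It converts a named pytree into a flat vector and also returns a function that reverses the conversion. The parameter names below are hypothetical. Dict keys flatten in sorted order, so the position of `einstein_radius` in the flat vector is an implementation detail. A caller should not hard-code it.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical named parameters.
params = {
    "source": {"intensity": jnp.array(2.0), "sigma": jnp.array(0.3)},
    "lens": {"einstein_radius": jnp.array(1.6), "redshift": jnp.array(0.5)},
}

flat, unravel = ravel_pytree(params)
# Dict keys flatten in sorted order: lens.einstein_radius, lens.redshift,
# source.intensity, source.sigma.
print(flat)

# Flat style: the caller must know that einstein_radius sits at index 0.
radius_by_index = flat[0]

# Pytree style: the name is the lookup.
radius_by_name = unravel(flat)["lens"]["einstein_radius"]
```
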

The rest of the codebase (tracers, galaxies, profiles) has moved
toward **pytree-native** inputs: nested dicts / `ModelInstance`
objects registered as JAX pytrees, so `jax.jit` flattens and
unflattens them automatically, and `vmap` batches over the pytree
leaves without the caller having to reshape anything.
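Sometimes only part of the parameter pytree should vary across a batch. For that case `jax.vmap` accepts an `in_axes` pytree that mirrors the argument's structure. A value of `0` maps over a leading axis, and `None` broadcasts the leaf unchanged. The sketch below sweeps the lens while holding the source fixed, with no manual reshaping. The parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_likelihood(params):
    # Toy function of one lens and one source parameter (not the MGE likelihood).
    radius = params["lens"]["einstein_radius"]
    return -((radius - 1.6) ** 2) * params["source"]["intensity"]


params = {
    "lens": {"einstein_radius": jnp.linspace(1.0, 2.0, 5)},  # batched leaf
    "source": {"intensity": jnp.array(2.0)},  # shared leaf
}

# in_axes is a tuple with one entry per positional argument; each entry may be
# a pytree prefix of that argument ("lens": 0 covers every leaf under "lens").
sweep = jax.vmap(log_likelihood, in_axes=({"lens": 0, "source": None},))
values = sweep(params)
print(values.shape)
```
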

## Task

Update `autolens_workspace_developer/jax_profiling/imaging/mge.py`
so that:

1. The JIT'd closure takes a **pytree** of parameters (e.g. a
   nested dict keyed by galaxy name → profile name → parameter
   name, or the `af.ModelInstance` object itself if it's now
   pytree-registered) rather than a flat 1D array.
2. `jax.jit` and `jax.vmap` are applied to that pytree-accepting
   closure directly — no flat-vector shim.
3. The profiling output (compile time, steady-state runtime, vmap
   batch throughput) should be unchanged within noise — this is a
   clarity / API-ergonomics change, not a performance change.
4. Update the inline commentary to point at the pytree inputs as
   the recommended pattern for new profiling scripts.

Once `mge.py` is updated, treat it as the new reference style and
propagate the same pattern to `pixelization.py` and `delaunay.py`
as a follow-up.

## Why this matters

- Downstream users reading the script will more easily see *which*
  parameter is which, because the pytree preserves names.
- `jax.jit` compile caches are keyed on the abstract pytree
  structure; this change makes the cache key human-readable.
- Eventually we will want to build user-facing `jax.jit`'d
  likelihood functions that accept a `ModelInstance` directly —
  the profiling script should demonstrate that end-state, not an
  older flat-vector style.
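A small experiment shows that `jax.jit` caches on the abstract pytree structure, together with leaf shapes and dtypes. A Python side effect inside a jitted function runs only while tracing, so counting it reveals when a new trace happens. The counter and parameter names are hypothetical. Passing new values with the same structure reuses the cached trace, while adding a parameter forces a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def total(params):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing.
    return sum(jax.tree_util.tree_leaves(params))


total({"lens": {"einstein_radius": jnp.array(1.6)}})
total({"lens": {"einstein_radius": jnp.array(1.9)}})  # same structure: cache hit
print(trace_count)  # 1

total({"lens": {"einstein_radius": jnp.array(1.6), "redshift": jnp.array(0.5)}})
print(trace_count)  # 2: new structure, new trace

# The treedef is the human-readable part of that cache key.
print(jax.tree_util.tree_structure({"lens": {"einstein_radius": 0.0}}))
```
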

## Risks / blockers

- If `af.ModelInstance` is not yet registered as a JAX pytree,
  this change either (a) needs that registration to happen first,
  or (b) uses a plain nested dict as an interim representation.
  Check `autofit`'s pytree registration status before starting;
  file a blocker issue if registration is missing.
- `model.instance_from_vector(..., xp=jnp)` is currently relied on
  by several other profiling scripts and by the production
  `Fitness.call`. This task *only* changes the profiling script,
  not the production path — do not touch `Fitness` or `model.instance_from_vector`.
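The registration status can be checked empirically before starting. JAX treats any unregistered object as a single opaque leaf. An object that comes back as its own only leaf is therefore not a pytree container, and `jax.jit` will not accept it as an argument. The sketch uses a hypothetical `UnregisteredInstance` class (not `af.ModelInstance`) and a hypothetical `is_pytree_node` helper. It also shows the interim nested-dict representation from option (b) working under `jax.jit`.

```python
import jax
import jax.numpy as jnp


class UnregisteredInstance:
    """Hypothetical object that has NOT been registered as a JAX pytree."""

    def __init__(self, einstein_radius):
        self.einstein_radius = einstein_radius


def is_pytree_node(obj):
    # Hypothetical helper: an unregistered object flattens to [obj] itself.
    leaves = jax.tree_util.tree_leaves(obj)
    return not (len(leaves) == 1 and leaves[0] is obj)


unregistered = UnregisteredInstance(jnp.array(1.6))
interim = {"einstein_radius": jnp.array(1.6)}  # option (b): plain nested dict

print(is_pytree_node(unregistered))  # False
print(is_pytree_node(interim))  # True


@jax.jit
def radius_squared(params):
    return params["einstein_radius"] ** 2


print(radius_squared(interim))
```
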

## Old Pytree Registration

Older versions (roughly 12 months ago) of PyAutoFit, PyAutoArray and PyAutoGalaxy had working pytree implementations.
Dig through their git history first and look for source code that can serve as a reference.
