# Imaging MGE JAX JIT Profiling — Migrate to Pytree Inputs
## Context
`autolens_workspace_developer/jax_profiling/imaging/mge.py` is the
flagship JAX JIT profiling script for the MGE source model. It
builds `Fitness.call` around a 1D parameter vector
(`jnp_params`, shape `(N,)`) and JITs that closure.
Passing parameters as a flat 1D vector has two long-term downsides:
1. **`model.instance_from_vector` is the JIT boundary**. That call
has to unpack the flat vector into every named structured
parameter (centres, ell_comps, intensities, sigmas, redshifts,
…) inside the trace. For a model with many parameters this is
pure Python shape-juggling that runs once per compile but
obscures the parameter identity under `jax.jit`.
2. **Batching with `vmap`** requires the caller to stack parameter
vectors into a 2D `(batch, N)` array. The structured parameter
names are lost along the way — any diagnostic that wants to
plot likelihood vs `einstein_radius` has to know the positional
index of `einstein_radius` in the flat vector, which is
fragile.
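Here is a minimal sketch of the flat-vector style that shows why it is fragile. It assumes a toy three-parameter model. `PARAM_ORDER` and `log_likelihood_flat` are hypothetical stand-ins, not names from `mge.py`. In the real script the ordering comes from the model's prior ordering.

```python
import jax
import jax.numpy as jnp

# Hypothetical positional layout. In mge.py the order is fixed by the model's
# prior ordering, so nothing in the vector itself records which entry is which.
PARAM_ORDER = ["einstein_radius", "gamma_1", "gamma_2"]


def log_likelihood_flat(params):
    # Toy likelihood: each parameter is recovered by position only.
    einstein_radius, gamma_1, gamma_2 = params[0], params[1], params[2]
    return -((einstein_radius - 1.6) ** 2) - gamma_1**2 - gamma_2**2


jit_flat = jax.jit(log_likelihood_flat)

# Batching requires stacking whole vectors into a (batch, N) array.
batch = jnp.stack([jnp.array([1.5, 0.01, 0.02]), jnp.array([1.6, 0.0, 0.0])])
log_likes = jax.vmap(log_likelihood_flat)(batch)

# A likelihood-vs-einstein_radius diagnostic must know the positional index.
einstein_radii = batch[:, PARAM_ORDER.index("einstein_radius")]
```

If the prior ordering changes, `PARAM_ORDER` and every hard-coded index must change with it. Nothing fails loudly if they drift apart.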
The rest of the codebase (tracers, galaxies, profiles) has moved
toward **pytree-native** inputs: nested dicts / `ModelInstance`
objects registered as JAX pytrees, so `jax.jit` flattens and
unflattens them automatically, and `vmap` batches over the pytree
leaves without the caller having to reshape anything.
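Here is a minimal sketch of the pytree-native alternative. It uses a plain nested dict (galaxy → profile → parameter) rather than a registered `af.ModelInstance`. The parameter names and the toy likelihood are hypothetical. Batching uses the `jax.tree_util.tree_map(... jnp.broadcast_to ...)` idiom that this task suggests for the vmap path, so every leaf gains a leading batch axis and the names survive.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested-dict parameters: galaxy name -> profile name -> parameter name.
params = {
    "lens": {
        "mass": {"einstein_radius": jnp.array(1.6), "centre": jnp.array([0.0, 0.0])},
        "shear": {"gamma_1": jnp.array(0.01), "gamma_2": jnp.array(0.02)},
    }
}


def log_likelihood(p):
    # Toy likelihood: parameters are read by name, never by position.
    mass, shear = p["lens"]["mass"], p["lens"]["shear"]
    centre_penalty = jnp.sum(mass["centre"] ** 2)
    shear_penalty = shear["gamma_1"] ** 2 + shear["gamma_2"] ** 2
    return -((mass["einstein_radius"] - 1.6) ** 2) - centre_penalty - shear_penalty


# jax.jit flattens the dict into leaves and rebuilds it inside the trace.
jit_log_likelihood = jax.jit(log_likelihood)

# Batch by giving every leaf a leading axis; no manual (batch, N) stacking.
batch_size = 4
batched = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), params
)
log_likes = jax.jit(jax.vmap(log_likelihood))(batched)

# Diagnostics look parameters up by name, e.g. likelihood vs einstein_radius.
einstein_radii = batched["lens"]["mass"]["einstein_radius"]
```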
## Task
Update `autolens_workspace_developer/jax_profiling/imaging/mge.py`
so that:
1. The JIT'd closure takes a **pytree** of parameters (e.g. a
nested dict keyed by galaxy name → profile name → parameter
name, or the `af.ModelInstance` object itself if it's now
pytree-registered) rather than a flat 1D array.
2. `jax.jit` and `jax.vmap` are applied to that pytree-accepting
closure directly — no flat-vector shim.
3. The profiling output (compile time, steady-state runtime, vmap
batch throughput) should be unchanged within noise — this is a
clarity / API-ergonomics change, not a performance change.
4. Update the inline commentary to point at the pytree inputs as
the recommended pattern for new profiling scripts.
Once `mge.py` is updated, treat it as the new reference style and
propagate the same pattern to `pixelization.py` and `delaunay.py`
as a follow-up.
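Below is one way to measure the three quantities in item 3 for a pytree-accepting closure. `ToyTracer`, `pytree_closure` and the toy grid are hypothetical stand-ins for `al.Tracer` and the script's real closures. The sketch illustrates a pattern under those assumptions:

- rebuild the structured object inside the trace;
- time `lower(...).compile()` separately from execution;
- call `block_until_ready()` so that asynchronous dispatch is not mistaken for runtime.

```python
import time

import jax
import jax.numpy as jnp


class ToyTracer:
    """Hypothetical stand-in for al.Tracer, built inside the trace from the pytree."""

    def __init__(self, galaxies):
        self.galaxies = galaxies

    def toy_signal(self, grid):
        # Toy per-galaxy term summed over galaxies; not a real lensing quantity.
        total = jnp.zeros(grid.shape[0])
        for galaxy in self.galaxies.values():
            total = total + galaxy["einstein_radius"] * jnp.linalg.norm(grid, axis=1)
        return total


grid = jnp.linspace(-1.0, 1.0, 2000).reshape(1000, 2)
params = {
    "lens": {"einstein_radius": jnp.array(1.6)},
    "subhalo": {"einstein_radius": jnp.array(0.1)},
}


def pytree_closure(p):
    tracer = ToyTracer(galaxies=p)  # rebuilt from the pytree, no flat-vector unpacking
    return jnp.sum(tracer.toy_signal(grid))


jitted = jax.jit(pytree_closure)

# Compile time: lower and compile explicitly so it is not mixed into runtime.
start = time.perf_counter()
compiled = jitted.lower(params).compile()
compile_time = time.perf_counter() - start

# Steady-state runtime: block_until_ready so async dispatch is not mis-timed.
compiled(params).block_until_ready()
n_repeats = 20
start = time.perf_counter()
for _ in range(n_repeats):
    compiled(params).block_until_ready()
steady_state_time = (time.perf_counter() - start) / n_repeats

# vmap throughput: vmap the same pytree closure over a batch of parameter sets.
batch_size = 32
batched = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), params
)
batched_call = jax.jit(jax.vmap(pytree_closure))
batched_call(batched).block_until_ready()  # warm-up, includes compilation
start = time.perf_counter()
batched_call(batched).block_until_ready()
throughput = batch_size / (time.perf_counter() - start)

print(f"compile {compile_time:.3f}s, step {steady_state_time * 1e3:.3f}ms, "
      f"vmap {throughput:.0f} evals/s")
```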
## Why this matters
- Downstream users reading the script will more easily see *which*
parameter is which, because the pytree preserves names.
- `jax.jit` compile caches are keyed on the abstract pytree
structure; this change makes the cache key human-readable.
- Eventually we will want to build user-facing `jax.jit`'d
likelihood functions that accept a `ModelInstance` directly —
the profiling script should demonstrate that end-state, not an
older flat-vector style.
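The cache-key claim can be checked directly. This sketch uses a toy likelihood and hypothetical parameter names. A Python counter increments only while JAX traces, so it counts compilations:

- Calling with new values under the same dict structure reuses the cache.
- Adding a key changes the treedef and forces a retrace.

The cache key also includes leaf shapes and dtypes. The treedef is the part that carries the names.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces, i.e. once per compile


def log_likelihood(p):
    global trace_count
    trace_count += 1
    # Toy likelihood over every leaf of the pytree.
    return -sum(leaf**2 for leaf in jax.tree_util.tree_leaves(p))


jit_ll = jax.jit(log_likelihood)

# Hypothetical parameter names.
params_a = {"lens": {"mass": {"einstein_radius": jnp.array(1.6)}}}
params_b = {"lens": {"mass": {"einstein_radius": jnp.array(1.2)}}}
params_c = {"lens": {"mass": {"einstein_radius": jnp.array(1.6), "slope": jnp.array(2.0)}}}

jit_ll(params_a)  # first call: trace and compile
jit_ll(params_b)  # same treedef, shapes and dtypes: cache hit, no retrace
after_same_structure = trace_count
jit_ll(params_c)  # extra key changes the treedef: retrace and recompile
after_new_structure = trace_count

# The treedef prints with the parameter names in it.
print(jax.tree_util.tree_structure(params_c))
```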
## Risks / blockers
- If `af.ModelInstance` is not yet registered as a JAX pytree,
this change either (a) needs that registration to happen first,
or (b) uses a plain nested dict as an interim representation.
Check `autofit`'s pytree registration status before starting;
file a blocker issue if registration is missing.
- `model.instance_from_vector(..., xp=jnp)` is currently relied on
by several other profiling scripts and by the production
`Fitness.call`. This task *only* changes the profiling script,
not the production path — do not touch `Fitness` or `model.instance_from_vector`.
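A quick way to check registration status is to see whether JAX flattens an object into children or treats it as a single opaque leaf. The sketch assumes nothing about autofit's internals. `UnregisteredInstance` and `double_radius` are hypothetical, and the nested dict shows the interim representation from option (b).

```python
import jax
import jax.numpy as jnp


def is_pytree_container(obj):
    """True if JAX flattens `obj` into children instead of treating it as one opaque leaf."""
    leaves = jax.tree_util.tree_leaves(obj)
    return not (len(leaves) == 1 and leaves[0] is obj)


class UnregisteredInstance:
    """Hypothetical stand-in for an af.ModelInstance that is NOT pytree-registered."""

    def __init__(self, einstein_radius):
        self.einstein_radius = einstein_radius


instance = UnregisteredInstance(einstein_radius=1.6)

# Option (b): an interim nested dict, which JAX always treats as a pytree.
interim = {"lens": {"mass": {"einstein_radius": jnp.asarray(instance.einstein_radius)}}}


def double_radius(p):
    return 2.0 * p["lens"]["mass"]["einstein_radius"]


doubled = jax.jit(double_radius)(interim)

print(is_pytree_container(instance))  # False: unregistered objects are opaque leaves
print(is_pytree_container(interim))  # True: dicts are flattened into their values
```

Running the same check on a real `af.ModelInstance`, and on the galaxy and profile objects it contains, shows how far registration currently reaches.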
## Historical pytree registration
Older versions (roughly 12 months ago) of PyAutoFit, PyAutoArray and PyAutoGalaxy had working pytree registrations. Dig through their git history first and look for source code there that can serve as a reference.
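When reading that history, it helps to know the general shape of a `tree_flatten` / `tree_unflatten` registration with `jax.tree_util.register_pytree_node`. `ToyIsothermal` is a hypothetical class, not PyAutoGalaxy's `mp.Isothermal`, and `toy_value` is a toy rather than a real profile formula. In this sketch:

- Array-valued fields become children, which are the traced leaves.
- Hashable metadata goes in `aux_data`, which becomes part of the treedef and therefore of the JIT cache key.

```python
import jax
import jax.numpy as jnp


class ToyIsothermal:
    """Hypothetical profile class using the tree_flatten / tree_unflatten method pattern."""

    def __init__(self, centre, einstein_radius, name="isothermal"):
        self.centre = centre
        self.einstein_radius = einstein_radius
        self.name = name  # static metadata: stored in aux_data, never traced

    def tree_flatten(self):
        children = (self.centre, self.einstein_radius)  # traced leaves
        aux_data = (self.name,)  # hashable, becomes part of the treedef
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        centre, einstein_radius = children
        (name,) = aux_data
        return cls(centre=centre, einstein_radius=einstein_radius, name=name)


# Same effect as decorating the class with @jax.tree_util.register_pytree_node_class.
jax.tree_util.register_pytree_node(
    ToyIsothermal, ToyIsothermal.tree_flatten, ToyIsothermal.tree_unflatten
)


@jax.jit
def toy_value(profile):
    # Toy function, not a real profile formula: the object is rebuilt inside the trace.
    r = jnp.linalg.norm(profile.centre) + 1.0
    return profile.einstein_radius / (2.0 * r)


profile = ToyIsothermal(centre=jnp.array([0.0, 0.0]), einstein_radius=jnp.array(1.6))
value = toy_value(profile)

# tree_map and vmap work on the registered class with no extra code.
batched_profiles = jax.tree_util.tree_map(lambda leaf: jnp.stack([leaf, leaf]), profile)
values = jax.vmap(toy_value)(batched_profiles)
```

One design point to check in the old commits: anything placed in `aux_data` must be hashable and comparable, and changing it triggers recompilation. Only genuinely static metadata belongs there.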
## Overview
The flagship JAX profiling script `jax_profiling/imaging/mge.py` currently JITs a closure that takes a flat 1D parameter vector and unpacks it via `model.instance_from_vector` inside the trace. This obscures parameter identity under `jax.jit`, forces `vmap` callers to stack `(batch, N)` arrays, and does not match the pytree-native direction the rest of the codebase has moved toward. This task migrates the script to pytree inputs so that JIT and vmap flatten/unflatten parameters automatically and preserve their names.
## Plan
1. Register `af.ModelInstance` (and supporting classes in autogalaxy / autoarray as needed) as a JAX pytree so it can cross a `jax.jit` boundary directly.
2. Replace the `jnp_params` JIT input in `mge.py` with a structured pytree.
3. Update `mapping_matrix_from_params`, `blurred_mm_from_params`, the full-pipeline JIT and the vmap path to consume the pytree directly, removing `model.instance_from_vector` from inside the trace.
4. Propagate the same pattern to `pixelization.py` and `delaunay.py` as a follow-up.
## Detailed implementation plan
### Affected Repositories
- `PyAutoLabs/autolens_workspace_developer` (primary, originating request)
- `rhayes777/PyAutoFit` (pytree registration on `PriorModel` / `ModelInstance`)
- `Jammy2211/PyAutoArray` (potentially: extend existing `uniform.py` / triangle pytree wrapping)
- `Jammy2211/PyAutoGalaxy` (potentially: `Galaxy`, `Tracer`, light/mass profiles)
### Work Classification
Library + Workspace. Run `/start_library` first; `autolens_workspace_developer` script changes follow once library-side registration is in place.
### Historical pytree references
- `autofit/mapper/prior_model/prior_model.py:281,289`: `tree_flatten` / `tree_unflatten` on `PriorModel` are already present; the `register_pytree_node` call at line 211 is currently commented out. Check `git blame` for the reason before re-enabling it.
- `feature/jax_pytree`, PR #1103
- `feature/graphical_pytrees`: `AnalysisFactor`, `FactorGraphModel` pytree methods, prior-id children
- `9a8d17fe` "tree flatten point and circle" (Sep 2024): `autoarray/structures/triangles/shape.py`
- `e6419042` + PR #137, `feature/jax_tracer` (Oct 2024): `autoarray/operators/over_sampling/uniform.py`, still present today
- `feature/jax_mge`, `feature/jax_tracer` branches
### Branch Survey
- Dirty-state snapshot: `z_staging/imaging-mge-pytree-migration-dirty-snapshot/` (do not touch)
- **Suggested branch:** `feature/imaging-mge-pytree-migration` (used consistently across all affected repos)
- **Worktree root:** `~/Code/PyAutoLabs-wt/imaging-mge-pytree-migration/` (created later by `/start_library`)
### Implementation Steps
1. Use `git show` / `git diff` on the historical commits listed above to understand the prior pytree approach. Read `autofit/mapper/prior_model/prior_model.py:205-295` and investigate why `register_pytree_node` was commented out.
2. Re-enable `register_pytree_node` on `PriorModel` (and `ModelInstance` if needed) in PyAutoFit. If the reason for the original comment-out still applies, work around it (e.g. selective leaf registration) rather than blindly uncommenting.
3. Extend pytree registration to `Galaxy`, `Tracer`, and the `lp_linear.MultiGaussian` / `mp.Isothermal` / `mp.ExternalShear` profile classes used in `mge.py`. Extend it only as far as needed to let a `ModelInstance` cross a JIT boundary and be rebuilt into a `Tracer` inside the trace.
4. Confirm that `jax.jit` / `jax.vmap` work over a simple `ModelInstance`.
5. Update `autolens_workspace_developer/jax_profiling/imaging/mge.py`:
   - Build `instance = model.instance_from_vector(vector=param_vector)` during setup (not inside the JIT), and pass `instance` (pytree-registered) directly as the JIT input.
   - `mapping_matrix_from_params` (`mge.py:336`): accept the pytree and build the `Tracer` via `al.Tracer(galaxies=list(instance.galaxies))`, with no `instance_from_vector` in the trace.
   - `blurred_mm_from_params` (`mge.py:373`): same change.
   - Replace the `fitness.call(jnp_params)` profile block (`mge.py:554-563`) with a pytree-accepting closure defined inside the script. Do NOT modify `Fitness.call` itself (explicitly out of scope per the original prompt).
   - vmap path: build the batch with `jax.tree_util.tree_map(lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), instance)`, vmap the new closure, and keep the correctness assertion.
   - Point `batched_call` at the pytree closure.
6. Verify with `cd jax_profiling/imaging && NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python mge.py`. Correctness asserts must pass, and timings must be within noise of the prior JSON results.
### Key Files
- `PyAutoFit/autofit/mapper/prior_model/prior_model.py`: re-enable pytree registration
- `PyAutoGalaxy/autogalaxy/galaxy/galaxy.py`, `.../profiles/**`: extend registration as needed
- `PyAutoLens/autolens/lens/tracer.py`: extend registration as needed
- `autolens_workspace_developer/jax_profiling/imaging/mge.py`: the migration
### Out of scope (explicit)
- `PyAutoFit/autofit/non_linear/fitness.py` (`Fitness.call`)
- PyAutoFit's `model.instance_from_vector` (production path)
- `jax_profiling/imaging/pixelization.py`: follow-up
- `jax_profiling/imaging/delaunay.py`: follow-up
- Dirty state of `autolens_workspace_developer` (snapshotted to `z_staging/imaging-mge-pytree-migration-dirty-snapshot/`; left alone)
## Original Prompt