Conversation
…, epic-final)
Create autogalaxy multi-wavelength jax_grad scripts under
jax_grad/multi/. Autolens has no multi jax_grad reference today, so
these are greenfield ports modelled on the matching
jax_likelihood_functions/multi/{lp,mge}.py scripts.
Each script wraps the FactorGraphModel log-likelihood (per-band
AnalysisImaging factors joined via af.FactorGraphModel) in
jax.value_and_grad. The gradient runs over the joint global parameter
vector spanning g and r bands. Asserts finite, expected shape, and
not-all-zero.
Closes #32 — final task of the autogalaxy_workspace_test parity epic
(#5).
This was referenced May 6, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Create
scripts/jax_grad/multi/{lp.py, mge.py}from scratch — autolens has no multijax_gradreference today. Each script joins per-bandAnalysisImagingfactors viaaf.FactorGraphModel(use_jax=True)and wraps the global log-likelihood injax.value_and_grad. The gradient runs over the joint parameter vector spanning the g and r bands.This is the final task (8/9) of the autogalaxy_workspace_test parity epic (#5). With this PR merged, the parity epic is complete.
Scripts Changed
scripts/jax_grad/multi/__init__.py— new (empty namespace marker).scripts/jax_grad/multi/lp.py— new. Two-band Sersic model ondataset/multi/jax_test/(mirrorsjax_likelihood_functions/multi/lp.py). Option B per-bandbulge.ell_compspriors viamodel.copy()+af.GaussianPrior; everything else shared.Fitness(model=factor_graph.global_prior_model, analysis=factor_graph)wrapsjax.value_and_grad(fitness.call). Assertions: finite value, shape(model.total_free_parameters,), all-finite grad, not-all-zero. Local run: 16.8s, gradient shape(9,).scripts/jax_grad/multi/mge.py— new. Same multi-band scaffold; bulge swapped to MGE linear basis (mge_model_from, 20 gaussians,centre_prior_is_uniform=True). Per-bandell_compsre-tied across all gaussians within each basis. Samejax.value_and_gradbody and assertions. Local run: 35.6s, gradient shape(6,).smoke_tests.txt— appended both new scripts immediately afterjax_grad/interferometer/mge.py.Notes
Layout divergence from autolens (final). Autolens has flat top-level
jax_grad/imaging_*.pyand no interferometer or multijax_gradscripts at all. This PR closes out the subfolder convention started in PR #29 (task 6) and continued in PR #31 (task 7) — autogalaxy now has the fulljax_grad/{imaging,interferometer,multi}/set. With all three subdirectories landed, the autolens-retrofit follow-up can now be filed as a single migration coveringjax_grad/imaging_lp.py→jax_grad/imaging/lp.pyand net-new interferometer/multi ports.lp.Sersic(notlp_linear). Same rationale as PR #31 — matches the validatedjax_likelihood_functions/multi/lp.pypattern.No env_vars.yaml change. The
jax_grad/substring override added in PR #29 coversjax_grad/multi/.MGE runtime (35.6s). On the slower end of the smoke-budget range — JIT-tracing a multi-band MGE FactorGraph + value_and_grad is heavier than the single-band scripts. Still well under the per-script smoke budget; flagging if future tasks add more variants on top of this.
Test Plan
.github/scripts/run_smoke.pywith the env override applied:jax_grad/multi/lp.py— exit 0 in 16.8s, grad shape(9,)jax_grad/multi/mge.py— exit 0 in 35.6s, grad shape(6,)🤖 Generated with Claude Code
Closes #32.