Skip to content

feat: scripts/jax_grad/multi/ — lp.py + mge.py from scratch (task 8/9, epic-final)#33

Merged
Jammy2211 merged 1 commit intomainfrom
feature/autogalaxy-wst-jax-grad-multi
May 6, 2026
Merged

feat: scripts/jax_grad/multi/ — lp.py + mge.py from scratch (task 8/9, epic-final)#33
Jammy2211 merged 1 commit intomainfrom
feature/autogalaxy-wst-jax-grad-multi

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Create scripts/jax_grad/multi/{lp.py, mge.py} from scratch — autolens has no multi jax_grad reference today. Each script joins per-band AnalysisImaging factors via af.FactorGraphModel(use_jax=True) and wraps the global log-likelihood in jax.value_and_grad. The gradient runs over the joint parameter vector spanning the g and r bands.

This is the final task (8/9) of the autogalaxy_workspace_test parity epic (#5). With this PR merged, the parity epic is complete.

Scripts Changed

  • scripts/jax_grad/multi/__init__.py — new (empty namespace marker).
  • scripts/jax_grad/multi/lp.py — new. Two-band Sersic model on dataset/multi/jax_test/ (mirrors jax_likelihood_functions/multi/lp.py). Option B per-band bulge.ell_comps priors via model.copy() + af.GaussianPrior; everything else shared. Fitness(model=factor_graph.global_prior_model, analysis=factor_graph) wraps jax.value_and_grad(fitness.call). Assertions: finite value, shape (model.total_free_parameters,), all-finite grad, not-all-zero. Local run: 16.8s, gradient shape (9,).
  • scripts/jax_grad/multi/mge.py — new. Same multi-band scaffold; bulge swapped to MGE linear basis (mge_model_from, 20 gaussians, centre_prior_is_uniform=True). Per-band ell_comps re-tied across all gaussians within each basis. Same jax.value_and_grad body and assertions. Local run: 35.6s, gradient shape (6,).
  • smoke_tests.txt — appended both new scripts immediately after jax_grad/interferometer/mge.py.

Notes

Layout divergence from autolens (final). Autolens has flat top-level jax_grad/imaging_*.py and no interferometer or multi jax_grad scripts at all. This PR closes out the subfolder convention started in PR #29 (task 6) and continued in PR #31 (task 7) — autogalaxy now has the full jax_grad/{imaging,interferometer,multi}/ set. With all three subdirectories landed, the autolens-retrofit follow-up can now be filed as a single migration covering jax_grad/imaging_lp.pyjax_grad/imaging/lp.py and net-new interferometer/multi ports.

lp.Sersic (not lp_linear). Same rationale as PR #31 — matches the validated jax_likelihood_functions/multi/lp.py pattern.

No env_vars.yaml change. The jax_grad/ substring override added in PR #29 covers jax_grad/multi/.

MGE runtime (35.6s). On the slower end of the smoke-budget range — JIT-tracing a multi-band MGE FactorGraph + value_and_grad is heavier than the single-band scripts. Still well under the per-script smoke budget; flagging if future tasks add more variants on top of this.

Test Plan

  • Both new scripts pass locally via .github/scripts/run_smoke.py with the env override applied:
    • jax_grad/multi/lp.py — exit 0 in 16.8s, grad shape (9,)
    • jax_grad/multi/mge.py — exit 0 in 35.6s, grad shape (6,)

🤖 Generated with Claude Code

Closes #32.

…, epic-final)

Create autogalaxy multi-wavelength jax_grad scripts under
jax_grad/multi/. Autolens has no multi jax_grad reference today, so
these are greenfield ports modelled on the matching
jax_likelihood_functions/multi/{lp,mge}.py scripts.

Each script wraps the FactorGraphModel log-likelihood (per-band
AnalysisImaging factors joined via af.FactorGraphModel) in
jax.value_and_grad. The gradient runs over the joint global parameter
vector spanning g and r bands. Asserts finite, expected shape, and
not-all-zero.

Closes #32 — final task of the autogalaxy_workspace_test parity epic
(#5).
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 6, 2026
@Jammy2211 Jammy2211 merged commit 8c62a8d into main May 6, 2026
0 of 4 checks passed
@Jammy2211 Jammy2211 deleted the feature/autogalaxy-wst-jax-grad-multi branch May 6, 2026 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: scripts/jax_grad/multi/ — lp.py + mge.py from scratch (task 8/9, epic-final)

1 participant