Skip to content

test: EllipseMultipoleScaled JAX parity script#51

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/multipole-scaled-jax
May 19, 2026
Merged

test: EllipseMultipoleScaled JAX parity script#51
Jammy2211 merged 1 commit into
mainfrom
feature/multipole-scaled-jax

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py — the JAX parity test for EllipseMultipoleScaled, mirroring the existing multipoles.py (EllipseMultipole). This closes the variant-coverage gap that allowed PyAutoGalaxy#426 to slip through prompt-7's verification: EllipseMultipoleScaled had no vmap-validation script, so its __init__-time np.* call on JAX tracers stayed invisible to smoke until a user HPC job hit it.

The new script runs numpy reference → fitness._vmap batch evaluation → jax.jit(analysis.fit_from)(instance) round-trip with np.testing.assert_allclose(rtol=1e-4) against the numpy total_log_likelihood. Same shape as multipoles.py so any future regression in either variant gets caught.

Caveat for reviewers: with multipole.major_axis=1.0 the scaling factor k = k_scaled * 1.0 is mathematically a no-op, so the numerical reference values match the non-scaled multipoles.py. The script still exercises EllipseMultipoleScaled.points_perturbed_from (including the deferred k_scaled * major_axis derivation under vmap/jit) — that's the code path that was broken before PR #427. Strengthening with a non-trivial major_axis case is a future enhancement.

Scripts Changed

  • scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py — new (~160 lines). Mirrors multipoles.py line-by-line but builds the multipole as af.Model(ag.EllipseMultipoleScaled) with scaled_multipole_comps priors and major_axis=1.0. Numpy reference, vmap block (50-batch through Fitness._vmap), and JIT round-trip block (jax.jit(analysis_jit.fit_from)(instance) + assert_allclose(rtol=1e-4) + isinstance(fit.log_likelihood, jnp.ndarray)).

Upstream PR

PyAutoLabs/PyAutoGalaxy#427

Test Plan

  • Smoke tests pass for autogalaxy_workspace_test (and adjacent workspaces, since smoke runs all 6)
  • python scripts/jax_likelihood_functions/ellipse/multipoles_scaled.py prints PASS: jit(fit_from) round-trip matches NumPy scalar.
  • Library-first merge gate: PR #427 must merge before this PR

🤖 Generated with Claude Code

Mirrors the existing multipoles.py (EllipseMultipole coverage) for the
scaled variant: numpy reference + fitness._vmap batch + jax.jit(fit_from)
round-trip with rtol=1e-4 assertion against the numpy total_log_likelihood.

This closes the variant-coverage gap that allowed PyAutoGalaxy#426 to
slip through the prompt-7 verification — going forward any regression
in EllipseMultipoleScaled's JAX path gets caught by smoke.

Caveat: with multipole.major_axis=1.0 the scaling factor is a no-op
(k = k_scaled * 1.0), so the numerical reference values match the
non-scaled multipoles.py. The script still exercises the
EllipseMultipoleScaled.points_perturbed_from code path including the
deferred k_scaled * major_axis derivation under vmap/jit — that's the
code path that was broken before PR #427. Adding a non-trivial
major_axis case is a future enhancement.

Issue PyAutoGalaxy#426. Upstream library PR: PyAutoGalaxy#427.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 19, 2026
@Jammy2211 Jammy2211 merged commit 433323d into main May 19, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/multipole-scaled-jax branch May 19, 2026 10:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant