feat: scripts/jax_grad/multi/ — lp.py + mge.py from scratch (task 8/9, epic-final) #32

@Jammy2211

Overview

Create scripts/jax_grad/multi/{lp.py, mge.py} in autogalaxy_workspace_test from scratch — autolens has no multi jax_grad reference. Each script wraps an af.FactorGraphModel over per-band AnalysisImaging factors in jax.value_and_grad and asserts the gradient is finite, the shape matches the joint global parameter vector spanning g and r bands, and the gradient isn't all-zero. This is the final task (8/9) of the autogalaxy_workspace_test parity epic (#5).

Plan

  • Create scripts/jax_grad/multi/{__init__.py, lp.py, mge.py} from scratch — autolens has no multi jax_grad reference. The multi pipeline uses af.FactorGraphModel + per-band AnalysisImaging factors with option-B per-band ell_comps.
  • Each script wraps the FactorGraphModel.log_likelihood (via Fitness(model=factor_graph.global_prior_model, analysis=factor_graph)) in jax.value_and_grad. Gradient runs over the joint global parameter vector spanning both g and r bands.
  • Assertions match tasks 6 and 7: finite value, shape matches factor_graph.global_prior_model.total_free_parameters, all-finite grad, not-all-zero grad.
  • Wire into smoke_tests.txt after jax_grad/interferometer/mge.py. The jax_grad/ env override from PR #29 (the jax_grad/imaging/ port, task 6/9) already covers jax_grad/multi/.
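
The gradient contract in the plan above can be sketched with a toy two-band likelihood. This is a minimal illustration in plain JAX, not the af.FactorGraphModel or Fitness API: the parameter layout, the 1D Gaussian profile, and the synthetic data are all assumptions made for the example. It shows one joint parameter vector holding shared parameters plus one per-band parameter for each of g and r, and the four assertions applied to the result of jax.value_and_grad.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the two-band factor graph. NOT the autofit API.
# Hypothetical joint layout: [centre (shared), width (shared), ell_g, ell_r]
N_FREE = 4
x = jnp.linspace(-1.0, 1.0, 32)

# Synthetic per-band "data" (illustrative only).
data_g = jnp.exp(-0.5 * (x / 0.3) ** 2)
data_r = 1.2 * jnp.exp(-0.5 * (x / 0.3) ** 2)


def band_log_likelihood(centre, width, ell, data):
    """Chi-squared style log-likelihood of one band for a 1D Gaussian profile."""
    model = (1.0 + ell) * jnp.exp(-0.5 * ((x - centre) / width) ** 2)
    return -0.5 * jnp.sum((data - model) ** 2)


def joint_log_likelihood(params):
    """Sum of per-band terms: shared centre/width, one ell parameter per band."""
    centre, width = params[0], params[1]
    ell_g, ell_r = params[2], params[3]
    return band_log_likelihood(centre, width, ell_g, data_g) + band_log_likelihood(
        centre, width, ell_r, data_r
    )


param_vector = jnp.array([0.05, 0.4, 0.1, 0.1])
value, grad = jax.value_and_grad(joint_log_likelihood)(param_vector)

# The four checks named in the plan.
assert jnp.isfinite(value), "log-likelihood is not finite"
assert grad.shape == (N_FREE,), "gradient shape != joint free-parameter count"
assert jnp.all(jnp.isfinite(grad)), "gradient has non-finite entries"
assert jnp.any(grad != 0.0), "gradient is all-zero"
```

In the real scripts the count would come from factor_graph.global_prior_model.total_free_parameters rather than a hard-coded constant.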
Detailed implementation plan

Affected Repositories

  • autogalaxy_workspace_test (primary)

Work Classification

Workspace

Branch Survey

| Repository | Current Branch | Dirty? | Notes |
|---|---|---|---|
| ./autogalaxy_workspace_test | main | dirty (3 modified dataset/*/jax_test/galaxies.json + untracked test_report.md) | Pre-existing dirt unrelated to this task. Worktree starts clean from origin/main (now at 4630b7c after task 7 / PR #31 merged). |

Suggested branch: feature/autogalaxy-wst-jax-grad-multi
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-grad-multi/ (created later by /start_workspace)

Implementation Steps

  1. Create scripts/jax_grad/multi/__init__.py (empty).
  2. scripts/jax_grad/multi/lp.py — mirror jax_likelihood_functions/multi/lp.py for setup:
    • waveband_list = ["g", "r"], pixel_scales=0.1, mask_radius=3.0, dataset/multi/jax_test/{g,r}_{data,psf,noise_map}.fits, with subprocess fallback to scripts/jax_likelihood_functions/multi/simulator.py.
    • Per-band Imaging.from_fits → apply_mask → apply_over_sampling(over_sample_size_lp=1).
    • Model: bulge = af.Model(ag.lp.Sersic); galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge); model = af.Collection(galaxies=af.Collection(galaxy=galaxy)).
    • Option B: model_per_band_list = [model.copy() with galaxies.galaxy.bulge.ell_comps.ell_comps_0/1 = af.GaussianPrior(mean=0.0, sigma=0.5)].
    • AnalysisImaging per band → AnalysisFactor per band → FactorGraphModel(*..., use_jax=True).
    • Fitness(model=factor_graph.global_prior_model, analysis=factor_graph, fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99).
    • param_vector = jnp.array(factor_graph.global_prior_model.physical_values_from_prior_medians), then apply the ell_comps perturbation, then value, grad = jax.value_and_grad(fitness.call)(param_vector).
    • Four assertions on factor_graph.global_prior_model.total_free_parameters. Trailing print("jax_grad/multi/lp.py JAX gradient checks passed.").
  3. scripts/jax_grad/multi/mge.py — same multi-band scaffold, MGE bulge instead of Sersic. Per-band ell_comps re-tying loop from jax_likelihood_functions/multi/mge.py:
    for gaussian in model_analysis.galaxies.galaxy.bulge.profile_list:
        gaussian.ell_comps.ell_comps_0 = ec_0
        gaussian.ell_comps.ell_comps_1 = ec_1
    Same factor-graph + Fitness + jax.value_and_grad body as lp.py.
  4. Append two lines to smoke_tests.txt immediately after jax_grad/interferometer/mge.py:
    jax_grad/multi/lp.py
    jax_grad/multi/mge.py
    
  5. No env_vars.yaml change: the jax_grad/ substring override added in PR #29 (the jax_grad/imaging/ port, task 6/9) already matches jax_grad/multi/.
  6. Local verify each via .github/scripts/run_smoke.py:run_one().
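
Step 4 above can be sketched as a small stdlib helper. This assumes smoke_tests.txt holds one script path per line; the insert_after function and the sample file contents are hypothetical and exist only for illustration. The helper places the new entries directly after the anchor line and skips any entry that is already present, so re-running it does not duplicate lines.

```python
import tempfile
from pathlib import Path

ANCHOR = "jax_grad/interferometer/mge.py"
NEW_ENTRIES = ["jax_grad/multi/lp.py", "jax_grad/multi/mge.py"]


def insert_after(lines, anchor, new_entries):
    """Return a copy of `lines` with `new_entries` placed right after `anchor`.

    Entries already present anywhere in `lines` are skipped (idempotent).
    """
    if anchor not in lines:
        raise ValueError(f"anchor {anchor!r} not found")
    to_add = [entry for entry in new_entries if entry not in lines]
    idx = lines.index(anchor) + 1
    return lines[:idx] + to_add + lines[idx:]


# Demonstration on a temporary file with illustrative contents.
with tempfile.TemporaryDirectory() as tmp:
    smoke_file = Path(tmp) / "smoke_tests.txt"
    smoke_file.write_text(
        "jax_grad/interferometer/lp.py\n"
        "jax_grad/interferometer/mge.py\n"
        "some/later/script.py\n"  # hypothetical trailing entry
    )

    original = smoke_file.read_text().splitlines()
    updated = insert_after(original, ANCHOR, NEW_ENTRIES)
    smoke_file.write_text("\n".join(updated) + "\n")

    # Applying the helper twice leaves the file unchanged.
    rerun = insert_after(smoke_file.read_text().splitlines(), ANCHOR, NEW_ENTRIES)

print("\n".join(updated))
```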

Key Files

  • autogalaxy_workspace_test/scripts/jax_grad/multi/__init__.py — new (empty namespace marker).
  • autogalaxy_workspace_test/scripts/jax_grad/multi/lp.py — new (Sersic, multi-band FactorGraph, jax.value_and_grad).
  • autogalaxy_workspace_test/scripts/jax_grad/multi/mge.py — new (MGE, multi-band FactorGraph, jax.value_and_grad).
  • autogalaxy_workspace_test/smoke_tests.txt — two new lines after jax_grad/interferometer/mge.py.

Risks & Notes

  • Multi pipeline JAX cost: the factor-graph jax.value_and_grad may be heavier than the single-band case, because each grad call traces both band graphs. Acceptable if it completes in under 30 s. The matching jax_likelihood_functions/multi/lp.py script runs in similar time (vmap is similarly expensive).
  • lp.Sersic (not lp_linear): keeping the validated jax_likelihood_functions/multi/lp.py pattern. Linear path under multi+JAX hasn't been validated for grad.
  • autolens layout divergence (final): tasks 6/7/8 establish the subfolder convention. After this PR ships, suggest filing the autolens-retrofit follow-up issue covering all three subdirectories.
  • Pytree prerequisite: factor_graph.global_prior_model registration scaffolding from task 5 / PR #19 (the jax_likelihood_functions/multi/ port) underpins the multi grad path.
  • Existing jax_grad/ env override from PR #29 (the jax_grad/imaging/ port, task 6/9) already unsets PYAUTO_SMALL_DATASETS and PYAUTO_DISABLE_JAX for any path containing jax_grad/.
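
One way to check the 30 s budget from the cost note locally is to time the gradient call directly. The sketch below uses a toy likelihood rather than the factor graph, and wrapping in jax.jit is an assumption (the issue does not say whether the scripts jit the call). Because JAX dispatches work asynchronously, the timer only stops after block_until_ready() on the outputs. The first call includes tracing and compilation, which is the cost the risk note is concerned with.

```python
import time

import jax
import jax.numpy as jnp

BUDGET_SECONDS = 30.0  # threshold taken from the risk note


def toy_log_likelihood(params):
    """Toy stand-in for the joint likelihood (not the factor graph)."""
    x = jnp.linspace(-1.0, 1.0, 256)
    data = jnp.exp(-0.5 * (x / 0.3) ** 2)
    model = params[0] * jnp.exp(-0.5 * ((x - params[1]) / params[2]) ** 2)
    return -0.5 * jnp.sum((data - model) ** 2)


# Assumption: the call is jitted; drop jax.jit to time the un-jitted path.
value_and_grad = jax.jit(jax.value_and_grad(toy_log_likelihood))
params = jnp.array([1.1, 0.05, 0.4])


def timed_call(fn, arg):
    """Wall-clock seconds for one call, waiting for the device to finish."""
    start = time.perf_counter()
    value, grad = fn(arg)
    value.block_until_ready()
    grad.block_until_ready()
    return time.perf_counter() - start


first_call = timed_call(value_and_grad, params)   # tracing + compile + run
second_call = timed_call(value_and_grad, params)  # cached executable only

print(f"first call: {first_call:.3f} s, second call: {second_call:.3f} s")
assert first_call < BUDGET_SECONDS, "gradient call exceeded the time budget"
```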

Original Prompt

Create scripts/jax_grad/multi/ in @autogalaxy_workspace_test exercising jax.grad on the
autogalaxy multi-dataset likelihood path.

Layout note

Same subfolder-vs-flat divergence as tasks 6 and 7. Flag in PR body.

Scripts

autolens has no multi jax_grad scripts today. Create from scratch using the
jax_likelihood_functions/multi/ templates (task 5).

Minimum coverage:

  • jax_grad/multi/lp.py
  • jax_grad/multi/mge.py

Additional variants only if feasible.

Skip: *_dspl.py.

Pytree prerequisite

Task 5 scaffolding (multi-dataset factor-graph pytree registration) must be complete.

jax.grad contract

Gradient is taken over the full parameter vector spanning all datasets. Assert finite and the
shape matches the combined free-parameter count.

Deliverables

  1. autogalaxy_workspace_test/scripts/jax_grad/multi/__init__.py
  2. Scripts above.
  3. Appended to smoke_tests.txt.

Depends on

Task 5 (multi-dataset pytree registration).

Umbrella issue

Task 8/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.
