Create scripts/jax_grad/multi/{lp.py, mge.py} in autogalaxy_workspace_test from scratch — autolens has no multi jax_grad reference. Each script wraps an af.FactorGraphModel over per-band AnalysisImaging factors in jax.value_and_grad and asserts the gradient is finite, the shape matches the joint global parameter vector spanning g and r bands, and the gradient isn't all-zero. This is the final task (8/9) of the autogalaxy_workspace_test parity epic (#5).
Plan
Create scripts/jax_grad/multi/{__init__.py, lp.py, mge.py} from scratch — autolens has no multi jax_grad reference. The multi pipeline uses af.FactorGraphModel + per-band AnalysisImaging factors with option-B per-band ell_comps.
Each script wraps FactorGraphModel.log_likelihood (via Fitness(model=factor_graph.global_prior_model, analysis=factor_graph)) in jax.value_and_grad. The gradient is taken over the joint global parameter vector spanning both the g and r bands.
Assertions match tasks 6 and 7: finite value, shape matches factor_graph.global_prior_model.total_free_parameters, all-finite grad, not-all-zero grad.
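The value_and_grad-plus-assertions pattern can be sketched on a toy problem. This is a minimal sketch under stated assumptions, not the real script. toy_image, joint_log_likelihood and the parameter layout are hypothetical stand-ins for the factor graph and Fitness.call. The layout is three shared parameters followed by one ell_comps pair per band, loosely mirroring the option-B layout. The toy derives its axis ratio from the square root of e1² + e2², so its gradient at exactly ell_comps = (0, 0) contains NaNs. Perturbing the ell_comps entries before calling jax.value_and_grad avoids that in this toy. Whether this is the actual reason for the ell_comps perturbation in autogalaxy is an assumption.

```python
import jax
import jax.numpy as jnp

BANDS = ["g", "r"]
N_SHARED = 3    # hypothetical shared parameters: centre_y, centre_x, log_intensity
N_PER_BAND = 2  # one ell_comps pair per band (option-B style, assumed)
TOTAL_FREE_PARAMETERS = N_SHARED + N_PER_BAND * len(BANDS)

_coords = jnp.linspace(-1.0, 1.0, 16)
GRID_Y, GRID_X = jnp.meshgrid(_coords, _coords, indexing="ij")
NOISE_SIGMA = 0.1


def toy_image(shared, ell_comps):
    """Toy elliptical Gaussian blob; axis ratio comes from |ell_comps| via sqrt."""
    centre_y, centre_x, log_intensity = shared
    fac = jnp.sqrt(ell_comps[0] ** 2 + ell_comps[1] ** 2)
    axis_ratio = (1.0 - fac) / (1.0 + fac)
    dy, dx = GRID_Y - centre_y, GRID_X - centre_x
    r2 = dx**2 + (dy / axis_ratio) ** 2
    return jnp.exp(log_intensity) * jnp.exp(-0.5 * r2 / 0.3**2)


# Toy "observed" data per band, generated from known true parameters.
TRUE_SHARED = jnp.array([0.1, -0.05, 0.0])
TRUE_ELL = {"g": jnp.array([0.1, 0.05]), "r": jnp.array([0.15, -0.02])}
DATA = {band: toy_image(TRUE_SHARED, TRUE_ELL[band]) for band in BANDS}


def joint_log_likelihood(param_vector):
    """Sum of per-band Gaussian log likelihoods over one joint parameter vector."""
    shared = param_vector[:N_SHARED]
    total = 0.0
    for i, band in enumerate(BANDS):
        start = N_SHARED + i * N_PER_BAND
        ell_comps = param_vector[start:start + N_PER_BAND]
        residual = (DATA[band] - toy_image(shared, ell_comps)) / NOISE_SIGMA
        total = total - 0.5 * jnp.sum(residual**2)
    return total


def check_gradient(log_likelihood, param_vector, total_free_parameters):
    """The four checks: finite value, grad shape, all-finite grad, not-all-zero grad."""
    value, grad = jax.value_and_grad(log_likelihood)(param_vector)
    assert jnp.isfinite(value), "log likelihood is not finite"
    assert grad.shape == (total_free_parameters,), grad.shape
    assert jnp.all(jnp.isfinite(grad)), "gradient has non-finite entries"
    assert jnp.any(grad != 0.0), "gradient is all zero"
    return value, grad


# Stand-in for the prior-median vector; then perturb every ell_comps entry.
medians = jnp.zeros(TOTAL_FREE_PARAMETERS)
param_vector = medians.at[N_SHARED:].add(0.05)
value, grad = check_gradient(joint_log_likelihood, param_vector, TOTAL_FREE_PARAMETERS)
print("toy multi-band gradient checks passed")
```

Calling jax.value_and_grad(joint_log_likelihood)(medians) directly returns NaN entries for the ell_comps parameters. This is why the toy checks the perturbed vector.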
Pre-existing dirt is unrelated to this task; the worktree starts clean from origin/main (now at 4630b7c, after task 7 / PR #31 merged).
Suggested branch: feature/autogalaxy-wst-jax-grad-multi
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-grad-multi/ (created later by /start_workspace)
Verify each script locally via .github/scripts/run_smoke.py:run_one().
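The following is a minimal sketch of a single smoke run. It is not the actual .github/scripts/run_smoke.py, and run_one()'s real signature is not shown in this plan. smoke_env and run_script are hypothetical names. The sketch reproduces the jax_grad/ env override behaviour from PR #29: PYAUTO_SMALL_DATASETS and PYAUTO_DISABLE_JAX are removed from the child environment for any path containing jax_grad/.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional, Tuple

# Variables the jax_grad/ override is described as unsetting.
JAX_GRAD_UNSET = ("PYAUTO_SMALL_DATASETS", "PYAUTO_DISABLE_JAX")


def smoke_env(script_path: str, base_env: Optional[Mapping[str, str]] = None) -> dict:
    """Copy the environment, dropping the JAX-disabling variables for jax_grad/ scripts."""
    env = dict(os.environ if base_env is None else base_env)
    if "jax_grad/" in script_path:
        for key in JAX_GRAD_UNSET:
            env.pop(key, None)
    return env


def run_script(script_path: str, timeout: float = 300.0) -> Tuple[bool, str]:
    """Run one script in a fresh interpreter; success means exit code 0."""
    result = subprocess.run(
        [sys.executable, script_path],
        env=smoke_env(script_path),
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return result.returncode == 0, result.stdout
```

For example, ok, stdout = run_script("scripts/jax_grad/multi/lp.py") should report success. In that case, stdout should also contain the trailing "JAX gradient checks passed." line.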
Key Files
autogalaxy_workspace_test/scripts/jax_grad/multi/__init__.py — new (empty namespace marker).
autogalaxy_workspace_test/scripts/jax_grad/multi/lp.py — new (Sersic, multi-band FactorGraph, jax.value_and_grad).
autogalaxy_workspace_test/scripts/jax_grad/multi/mge.py — new (MGE, multi-band FactorGraph, jax.value_and_grad).
autogalaxy_workspace_test/smoke_tests.txt — two new lines after jax_grad/interferometer/mge.py.
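Adding the two smoke-test entries is a small list edit. This minimal sketch assumes smoke_tests.txt holds one script path per line. It also assumes the new entries read jax_grad/multi/lp.py and jax_grad/multi/mge.py; the exact entry text is an assumption. insert_after and the sample list are hypothetical.

```python
from typing import List

ANCHOR = "jax_grad/interferometer/mge.py"
NEW_ENTRIES = ["jax_grad/multi/lp.py", "jax_grad/multi/mge.py"]  # assumed entry text


def insert_after(lines: List[str], anchor: str, new_entries: List[str]) -> List[str]:
    """Return a copy of lines with new_entries placed right after anchor.

    Entries that are already present are skipped, so re-running is a no-op.
    """
    stripped = [line.strip() for line in lines]
    if anchor not in stripped:
        raise ValueError(f"anchor {anchor!r} not found")
    to_add = [entry for entry in new_entries if entry not in stripped]
    position = stripped.index(anchor) + 1
    return lines[:position] + to_add + lines[position:]


# Hypothetical excerpt of smoke_tests.txt, one path per line.
sample = [
    "jax_grad/imaging/lp.py",
    "jax_grad/imaging/mge.py",
    "jax_grad/interferometer/lp.py",
    "jax_grad/interferometer/mge.py",
]
updated = insert_after(sample, ANCHOR, NEW_ENTRIES)
print("\n".join(updated))
```

Skipping entries that are already present makes the edit safe to repeat if a rebase replays it.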
Risks & Notes
Multi-pipeline JAX cost: the factor-graph jax.value_and_grad call may be heavier than the single-band one, because each grad call traces both band graphs. This is acceptable if each script completes in under 30 s; the matching jax_likelihood_functions/multi/lp.py script runs in similar time (its vmap is similarly expensive).
lp.Sersic (not lp_linear): this keeps the validated jax_likelihood_functions/multi/lp.py pattern, since the linear path under multi + JAX has not been validated for grad.
autolens layout divergence (final): tasks 6/7/8 establish the subfolder convention. After this PR ships, suggest filing the autolens-retrofit follow-up issue covering all three subdirectories.
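One way to check the under-30 s budget is to time a few value_and_grad calls. This minimal sketch uses a hypothetical two-term toy_two_band_log_likelihood in place of the factor graph. The first call includes tracing. Wrapping the gradient function in jax.jit is a design option, not something the plan prescribes: it lets later calls with the same input shapes reuse the compiled graph instead of re-tracing. block_until_ready() is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

BUDGET_SECONDS = 30.0
X = jnp.linspace(-1.0, 1.0, 256)


def toy_two_band_log_likelihood(params):
    """Hypothetical stand-in: one Gaussian-profile term per band, sharing params."""
    total = 0.0
    for width in (0.3, 0.4):  # one term per band
        model = params[0] * jnp.exp(-0.5 * ((X - params[1]) / width) ** 2)
        total = total - 0.5 * jnp.sum((model - 0.5) ** 2)
    return total


def time_calls(grad_fn, params, n_calls=3):
    """Wall-clock seconds per call, waiting for the asynchronous result each time."""
    timings = []
    for _ in range(n_calls):
        start = time.perf_counter()
        _, grad = grad_fn(params)
        grad.block_until_ready()
        timings.append(time.perf_counter() - start)
    return timings


params = jnp.array([1.0, 0.1])
eager = time_calls(jax.value_and_grad(toy_two_band_log_likelihood), params)
jitted = time_calls(jax.jit(jax.value_and_grad(toy_two_band_log_likelihood)), params)
print("eager:", eager)
print("jitted (first call includes compilation):", jitted)
assert max(eager + jitted) < BUDGET_SECONDS
```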
Detailed implementation plan
Branch Survey
Pre-existing dirt (… dataset/*/jax_test/galaxies.json + untracked test_report.md) is unrelated to this task; the worktree starts clean from origin/main (now at 4630b7c after task 7 / PR #31 merged).
Implementation Steps
1. Create scripts/jax_grad/multi/__init__.py (empty).
2. Create scripts/jax_grad/multi/lp.py, mirroring jax_likelihood_functions/multi/lp.py for setup:
   - waveband_list = ["g", "r"], pixel_scales=0.1, mask_radius=3.0, dataset/multi/jax_test/{g,r}_{data,psf,noise_map}.fits, with a subprocess fallback to scripts/jax_likelihood_functions/multi/simulator.py.
   - Imaging.from_fits → apply_mask → apply_over_sampling(over_sample_size_lp=1).
   - bulge = af.Model(ag.lp.Sersic); galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge); model = af.Collection(galaxies=af.Collection(galaxy=galaxy)).
   - model_per_band_list = [model.copy() with galaxies.galaxy.bulge.ell_comps.ell_comps_0/1 = af.GaussianPrior(mean=0.0, sigma=0.5)].
   - AnalysisImaging per band → AnalysisFactor per band → FactorGraphModel(*..., use_jax=True).
   - Fitness(model=factor_graph.global_prior_model, analysis=factor_graph, fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99).
   - param_vector = jnp.array(factor_graph.global_prior_model.physical_values_from_prior_medians), then an ell_comps perturbation, then value, grad = jax.value_and_grad(fitness.call)(param_vector).
   - Assertions as in tasks 6 and 7: finite value, grad shape matches factor_graph.global_prior_model.total_free_parameters, all-finite grad, not-all-zero grad. Trailing print("jax_grad/multi/lp.py JAX gradient checks passed.").
3. Create scripts/jax_grad/multi/mge.py: the same multi-band scaffold, with an MGE bulge instead of Sersic. Use the per-band ell_comps re-tying loop from jax_likelihood_functions/multi/mge.py and the same jax.value_and_grad body as lp.py.
4. Add the two new scripts to smoke_tests.txt immediately after jax_grad/interferometer/mge.py. The jax_grad/ substring override added in PR #29 (task 6/9, the scripts/jax_grad/imaging/ lp.py + mge.py port) already matches jax_grad/multi/.
5. Verify each script locally via .github/scripts/run_smoke.py:run_one().
Risks & Notes
The factor_graph.global_prior_model registration scaffolding from task 5 (PR #19, feat: jax_likelihood_functions/multi/ port) underpins the multi grad path.
The jax_grad/ env override from PR #29 (task 6/9, the scripts/jax_grad/imaging/ lp.py + mge.py port) already unsets PYAUTO_SMALL_DATASETS and PYAUTO_DISABLE_JAX for any path containing jax_grad/.
Original Prompt
Create scripts/jax_grad/multi/ in @autogalaxy_workspace_test exercising jax.grad on the autogalaxy multi-dataset likelihood path.
Layout note
Same subfolder-vs-flat divergence as tasks 6 and 7. Flag in PR body.
Scripts
autolens has no multi jax_grad scripts today. Create them from scratch using the jax_likelihood_functions/multi/ templates (task 5).
Minimum coverage:
jax_grad/multi/lp.py
jax_grad/multi/mge.py
Additional variants only if feasible.
Skip: *_dspl.py.
Pytree prerequisite
Task 5 scaffolding (multi-dataset factor-graph pytree registration) must be complete.
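Pytree registration is what lets multi-dataset objects pass through jax.jit and jax.grad. The general mechanism can be sketched with jax.tree_util.register_pytree_node_class. ToyBandData and joint_chi_squared are hypothetical and stand in for whatever classes task 5 actually registered. Array fields become pytree leaves that JAX traces, while the band label travels as static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyBandData:
    """Hypothetical per-band container: arrays are leaves, the band name is static."""

    def __init__(self, band, data, noise_map):
        self.band = band
        self.data = data
        self.noise_map = noise_map

    def tree_flatten(self):
        children = (self.data, self.noise_map)  # traced by JAX
        aux_data = self.band                    # hashed into the tree structure
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)


@jax.jit
def joint_chi_squared(model_value, datasets):
    """Chi-squared summed over every band for a single shared model value."""
    total = 0.0
    for dataset in datasets:
        total = total + jnp.sum(((dataset.data - model_value) / dataset.noise_map) ** 2)
    return total


datasets = [
    ToyBandData("g", jnp.ones(4), jnp.full(4, 0.5)),
    ToyBandData("r", jnp.full(4, 2.0), jnp.ones(4)),
]
chi2 = joint_chi_squared(1.5, datasets)
grad = jax.grad(joint_chi_squared)(1.5, datasets)
print(float(chi2), float(grad))
```

Without the decorator, passing ToyBandData instances into the jitted function would fail, because JAX would not know how to flatten them into arrays.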
jax.grad contract
The gradient is taken over the full parameter vector spanning all datasets. Assert that it is finite and that its shape matches the combined free-parameter count.
Deliverables
autogalaxy_workspace_test/scripts/jax_grad/multi/__init__.py
… smoke_tests.txt.
Depends on
Task 5 (multi-dataset pytree registration).
Umbrella issue
Task 8/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.