Overview
Create scripts/jax_grad/interferometer/{lp.py, mge.py} in autogalaxy_workspace_test from scratch, since autolens has no interferometer jax_grad reference. Each script exercises jax.value_and_grad on the autogalaxy AnalysisInterferometer likelihood path and asserts that the gradient is finite, that its shape matches the model's free-parameter count, and that it isn't all-zero. This is task 7/9 of the autogalaxy_workspace_test parity epic (#5).
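The checks named in this overview, together with the finite-likelihood check the scripts also make, fit in a small helper. This is a stdlib-only sketch, not code from the workspace: check_gradient_contract is a hypothetical name (the plan writes the assertions inline), and it assumes the gradient is passed as a flat list of floats, for example grad.tolist() on the 1-D array returned by jax.value_and_grad.

```python
import math
from typing import Sequence


def check_gradient_contract(
    value: float, grad: Sequence[float], total_free_parameters: int
) -> None:
    """Raise AssertionError unless (value, grad) satisfy the jax_grad contract.

    Hypothetical helper: `grad` is assumed to be a flat sequence of floats,
    e.g. `grad.tolist()` for the 1-D array returned by jax.value_and_grad.
    """
    # 1. The log likelihood itself must be finite.
    assert math.isfinite(value), f"log likelihood is not finite: {value}"
    # 2. One gradient entry per free parameter, mirroring
    #    grad.shape == (model.total_free_parameters,) for a 1-D array.
    assert len(grad) == total_free_parameters, (
        f"expected {total_free_parameters} gradient entries, got {len(grad)}"
    )
    # 3. No NaN or inf anywhere in the gradient.
    assert all(math.isfinite(g) for g in grad), "gradient contains NaN or inf"
    # 4. At least one non-zero entry, i.e. gradients actually flow.
    assert any(g != 0.0 for g in grad), "gradient is identically zero"


if __name__ == "__main__":
    # A passing example: finite value, three finite entries, not all zero.
    check_gradient_contract(-1234.5, [0.25, 0.0, -3.0], total_free_parameters=3)
    print("gradient contract satisfied")
```

In lp.py or mge.py the inline assertions could then become check_gradient_contract(float(value), grad.tolist(), model.total_free_parameters). Keeping them inline, as the plan does, has the advantage that each script reads on its own.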
Plan
Create scripts/jax_grad/interferometer/{__init__.py, lp.py, mge.py} from scratch (autolens has no interferometer jax_grad reference — this is greenfield).
Mirror the model/dataset setup of jax_likelihood_functions/interferometer/{lp,mge}.py exactly: same jax_test dataset, (256,256) real-space mask, pixel_scales=0.1, radius=3.0, TransformerDFT. Use plain ag.lp.Sersic (not lp_linear) for lp.py to match the validated JAX-likelihood pattern for interferometer.
Body of each script: jax.value_and_grad(fitness.call) with the same four assertions as task 6 (finite value, shape match, all-finite grad, not-all-zero grad).
Wire into smoke_tests.txt. The jax_grad/ env-vars override added in task 6 already covers jax_grad/interferometer/ (substring match), so no env_vars.yaml change.
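Why a substring rule needs no new env_vars.yaml entry can be illustrated with a short sketch. It rests on assumptions: the real override lives in env_vars.yaml and is applied by the smoke-test runner, neither of whose schema or code is shown here, so UNSET_OVERRIDES and env_for_script are hypothetical names. The two variable names are the ones the task 6 override is said to unset.

```python
import os
from typing import Dict, List, Mapping

# Hypothetical in-memory form of the task 6 override: any script path that
# contains the key as a substring gets the listed variables removed.
# The real env_vars.yaml schema is not shown in this plan.
UNSET_OVERRIDES: Dict[str, List[str]] = {
    "jax_grad/": ["PYAUTO_SMALL_DATASETS", "PYAUTO_DISABLE_JAX"],
}


def env_for_script(
    script_path: str,
    base_env: Mapping[str, str],
    overrides: Mapping[str, List[str]] = UNSET_OVERRIDES,
) -> Dict[str, str]:
    """Return a copy of base_env with override variables unset for matching paths."""
    env = dict(base_env)  # never mutate the caller's environment
    # Normalise the platform path separator so Windows-style paths
    # also match the forward-slash pattern "jax_grad/".
    normalised = script_path.replace(os.sep, "/")
    for pattern, names in overrides.items():
        if pattern in normalised:  # plain substring match, no globbing
            for name in names:
                env.pop(name, None)
    return env
```

Because the match is a plain substring test, jax_grad/interferometer/lp.py and jax_grad/interferometer/mge.py pick up the override exactly as jax_grad/imaging/ scripts do, and so would any future jax_grad/ subfolder.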
Detailed implementation plan
Branch Survey
Pre-existing dirt (dataset/*/jax_test/galaxies.json plus an untracked test_report.md) is unrelated to this task. The worktree starts clean from origin/main (now at 4814b07 after task 6 / PR #29 merged).
Suggested branch: feature/autogalaxy-wst-jax-grad-interferometer
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-grad-interferometer/ (created later by /start_workspace)
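Implementation Steps
1. Create scripts/jax_grad/interferometer/__init__.py (empty namespace marker).
2. Create scripts/jax_grad/interferometer/lp.py, mirroring autogalaxy_workspace_test/scripts/jax_likelihood_functions/interferometer/lp.py for setup:
   - dataset_path = path.join("dataset", "interferometer", "jax_test"), with a subprocess fallback to scripts/jax_likelihood_functions/interferometer/simulator.py.
   - real_space_mask = ag.Mask2D.circular(shape_native=(256,256), pixel_scales=0.1, radius=3.0).
   - dataset = ag.Interferometer.from_fits(... transformer_class=ag.TransformerDFT).
   - bulge = af.Model(ag.lp.Sersic); galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge); model = af.Collection(galaxies=af.Collection(galaxy=galaxy)).
   - analysis = ag.AnalysisInterferometer(dataset=dataset).
   - Fitness setup, param_vector = jnp.array(model.physical_values_from_prior_medians), ell_comps perturbation, then value, grad = jax.value_and_grad(fitness.call)(param_vector).
   - Four assertions: finite value, grad.shape == (model.total_free_parameters,), all-finite grad, not-all-zero grad. Trailing print("jax_grad/interferometer/lp.py JAX gradient checks passed.").
3. Create scripts/jax_grad/interferometer/mge.py with the same dataset and mask, but with the model swapped to a single galaxy from mge_model_from(mask_radius=3.0, total_gaussians=20, centre_prior_is_uniform=True). Same body and assertions.
4. Append two lines to smoke_tests.txt immediately after jax_grad/imaging/mge.py. No env_vars.yaml change is needed: the jax_grad/ substring override added in task 6 already matches jax_grad/interferometer/.
5. Verify with activate.sh: run each new script directly, then exercise it via .github/scripts/run_smoke.py:run_one() to confirm CI behavior.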
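The dataset bootstrap described for lp.py (use the jax_test dataset if it exists, otherwise run the interferometer simulator in a subprocess) can be sketched as follows. This is an illustration under assumptions: ensure_dataset and its injectable run argument are hypothetical, paths are taken relative to the workspace root, and the existing jax_likelihood_functions scripts may structure the fallback differently.

```python
import subprocess
import sys
from os import path
from typing import Callable


def ensure_dataset(
    dataset_path: str,
    simulator_script: str,
    run: Callable[..., object] = subprocess.run,
) -> bool:
    """Simulate the dataset if it is missing; return True if the simulator ran."""
    if path.exists(dataset_path):
        return False
    # Run the simulator with the current interpreter so it uses the same
    # virtual environment; check=True turns a failed simulation into an error.
    run([sys.executable, simulator_script], check=True)
    return True
```

In lp.py it would be called with path.join("dataset", "interferometer", "jax_test") and path.join("scripts", "jax_likelihood_functions", "interferometer", "simulator.py") before ag.Interferometer.from_fits reads the data. The run parameter exists only so the sketch can be tested without launching a real process.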
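Key Files
autogalaxy_workspace_test/scripts/jax_grad/interferometer/__init__.py: new (empty namespace marker).
autogalaxy_workspace_test/scripts/jax_grad/interferometer/lp.py: new (Sersic single galaxy, TransformerDFT, jax.value_and_grad).
autogalaxy_workspace_test/scripts/jax_grad/interferometer/mge.py: new (MGE single galaxy, same setup, jax.value_and_grad).
autogalaxy_workspace_test/smoke_tests.txt: two new lines after jax_grad/imaging/mge.py.
Risks & Notes
lp.Sersic vs lp_linear.Sersic: task 6 used lp_linear.Sersic; this task uses plain lp.Sersic to match the validated jax_likelihood_functions/interferometer/lp.py pattern. The linear-inversion path under interferometer + JAX may exercise a different code path, so this task sticks with the proven pattern.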
DFT vs NUFFT transformer: this task uses TransformerDFT (matching jax_likelihood_functions/interferometer/lp.py). The NUFFT path is out of scope for the minimum-coverage deliverable.
Pytree prerequisite: _register_fit_interferometer_pytrees is in place at autogalaxy/interferometer/model/analysis.py:147 (shipped via task 4).
Existing jax_grad/ env override from task 6 already unsets PYAUTO_SMALL_DATASETS and PYAUTO_DISABLE_JAX for any path containing jax_grad/.
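The DFT vs NUFFT note above is easier to weigh with the underlying sum in view. A direct Fourier transform evaluates every visibility as V(u, v) = sum_j I_j * exp(-2*pi*i*(u*x_j + v*y_j)) over all unmasked pixels j, so its cost grows as N_visibilities * N_pixels; a NUFFT approximates the same sum faster by gridding and using FFTs. The plain-Python sketch below is illustrative only: it is not autogalaxy's TransformerDFT code, and the (y, x) pixel ordering, arcsecond-to-radian conversion and sign convention are assumptions.

```python
import cmath
import math
from typing import List, Sequence, Tuple

# Conversion from arcseconds (image-plane grid units) to radians.
ARCSEC_TO_RAD = math.pi / (180.0 * 3600.0)


def direct_fourier_transform(
    image: Sequence[float],
    grid_yx_arcsec: Sequence[Tuple[float, float]],
    uv_wavelengths: Sequence[Tuple[float, float]],
) -> List[complex]:
    """Return one complex visibility per (u, v) baseline.

    image[j] is the surface brightness of unmasked pixel j, located at
    grid_yx_arcsec[j] = (y_j, x_j). Cost is O(len(uv) * len(image)).
    """
    visibilities: List[complex] = []
    for u, v in uv_wavelengths:
        total = 0j
        for intensity, (y, x) in zip(image, grid_yx_arcsec):
            phase = -2.0 * math.pi * (u * x + v * y) * ARCSEC_TO_RAD
            total += intensity * cmath.exp(1j * phase)
        visibilities.append(total)
    return visibilities


if __name__ == "__main__":
    # At zero baseline every phase is zero, so V(0, 0) is the total flux.
    print(direct_fourier_transform([1.0, 3.0], [(0.5, 1.0), (-0.2, 0.4)], [(0.0, 0.0)]))
```

With the plan's mask (radius=3.0, pixel_scales=0.1) roughly pi * 30^2, about 2,800, pixels are unmasked, so each visibility costs a few thousand complex exponentials. That is manageable for the jax_test dataset, which is consistent with treating NUFFT as out of scope here.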
Original Prompt
Create scripts/jax_grad/interferometer/ in @autogalaxy_workspace_test exercising jax.grad on the autogalaxy interferometer likelihood path.
Layout note
Same subfolder-vs-flat divergence from autolens as task 6 — autogalaxy uses jax_grad/interferometer/. Flag in PR body; don't retrofit autolens without user go-ahead.
Scripts
autolens has no interferometer jax_grad scripts today. This task creates them from scratch for autogalaxy, using the corresponding jax_likelihood_functions/interferometer/ scripts as templates for model + dataset setup, then wrapping the likelihood in jax.grad.
Minimum coverage:
jax_grad/interferometer/lp.py
jax_grad/interferometer/mge.py
Additional variants only if the jax_likelihood_functions/interferometer/ task surfaced a ready-to-grad path.
Skip: *_dspl.py.
Pytree prerequisite
Task 4 must have landed (AnalysisInterferometer pytree registration on PyAutoGalaxy).
jax.grad contract
Same as task 6 — finite gradient, correct free-parameter shape.
Deliverables
autogalaxy_workspace_test/scripts/jax_grad/interferometer/__init__.py
… smoke_tests.txt.
Depends on
Task 4 (pytree registration on PyAutoGalaxy interferometer analysis).
Umbrella issue
Task 7/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.