
feat: scripts/jax_grad/interferometer/ — lp.py + mge.py from scratch (task 7/9) #30

@Jammy2211

Description


Overview

Create scripts/jax_grad/interferometer/{lp.py, mge.py} in autogalaxy_workspace_test from scratch, since autolens has no interferometer jax_grad reference. Each script runs jax.value_and_grad on the autogalaxy AnalysisInterferometer likelihood path and asserts that the likelihood value and gradient are finite, that the gradient's shape matches the model's free-parameter count, and that the gradient is not all-zero. This is task 7/9 of the autogalaxy_workspace_test parity epic (#5).

Plan

  • Create scripts/jax_grad/interferometer/{__init__.py, lp.py, mge.py} from scratch (autolens has no interferometer jax_grad reference — this is greenfield).
  • Mirror the model/dataset setup of jax_likelihood_functions/interferometer/{lp,mge}.py exactly: same jax_test dataset, (256,256) real-space mask, pixel_scales=0.1, radius=3.0, TransformerDFT. Use plain ag.lp.Sersic (not lp_linear) for lp.py to match the validated JAX-likelihood pattern for interferometer.
  • Body of each script: jax.value_and_grad(fitness.call) with the same four assertions as task 6 (finite value, shape match, all-finite grad, not-all-zero grad).
  • Wire into smoke_tests.txt. The jax_grad/ env-vars override added in task 6 already covers jax_grad/interferometer/ (substring match), so no env_vars.yaml change.
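A minimal sketch of why a plain substring test already covers the new folder. It assumes the override is applied by checking whether the script path contains "jax_grad/"; the OVERRIDES table and env_for_script helper are hypothetical, since the actual env_vars.yaml schema and run_smoke.py logic are not shown in this issue.

```python
# Hypothetical override table (assumption): script-path substring -> env vars
# to unset. Modeled on the task 6 description, not on the real env_vars.yaml.
OVERRIDES = {
    "jax_grad/": ["PYAUTO_SMALL_DATASETS", "PYAUTO_DISABLE_JAX"],
}


def env_for_script(script_path, base_env):
    """Return a copy of base_env with every matching override applied."""
    env = dict(base_env)
    for pattern, names in OVERRIDES.items():
        # Plain substring match, so "jax_grad/" also matches any subfolder
        # such as "jax_grad/interferometer/lp.py".
        if pattern in script_path:
            for name in names:
                env.pop(name, None)
    return env


base = {"PYAUTO_SMALL_DATASETS": "1", "PYAUTO_DISABLE_JAX": "1", "PATH": "/usr/bin"}
print(sorted(env_for_script("jax_grad/interferometer/lp.py", base)))
# ['PATH']
print(sorted(env_for_script("jax_likelihood_functions/interferometer/lp.py", base)))
# ['PATH', 'PYAUTO_DISABLE_JAX', 'PYAUTO_SMALL_DATASETS']
```

Under this assumption, adding a new subfolder under jax_grad/ never requires a new override entry; only paths outside jax_grad/ keep the small-dataset and JAX-disable flags.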
Detailed implementation plan

Affected Repositories

  • autogalaxy_workspace_test (primary)

Work Classification

Workspace

Branch Survey

| Repository | Current Branch | Dirty? | Notes |
| --- | --- | --- | --- |
| ./autogalaxy_workspace_test | main | dirty (3 modified dataset/*/jax_test/galaxies.json + untracked test_report.md) | Pre-existing dirt unrelated to this task. Worktree starts clean from origin/main (now at 4814b07 after task 6 / PR #29 merged). |

Suggested branch: feature/autogalaxy-wst-jax-grad-interferometer
Worktree root: ~/Code/PyAutoLabs-wt/autogalaxy-wst-jax-grad-interferometer/ (created later by /start_workspace)

Implementation Steps

  1. Create scripts/jax_grad/interferometer/__init__.py (empty namespace marker).
  2. scripts/jax_grad/interferometer/lp.py — mirror autogalaxy_workspace_test/scripts/jax_likelihood_functions/interferometer/lp.py for setup:
    • dataset_path = path.join("dataset", "interferometer", "jax_test"), with subprocess fallback to scripts/jax_likelihood_functions/interferometer/simulator.py.
    • real_space_mask = ag.Mask2D.circular(shape_native=(256,256), pixel_scales=0.1, radius=3.0).
    • dataset = ag.Interferometer.from_fits(... transformer_class=ag.TransformerDFT).
    • Model: bulge = af.Model(ag.lp.Sersic); galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge); model = af.Collection(galaxies=af.Collection(galaxy=galaxy)).
    • analysis = ag.AnalysisInterferometer(dataset=dataset).
    • Build the Fitness wrapper, set param_vector = jnp.array(model.physical_values_from_prior_medians), apply the ell_comps perturbation, then compute value, grad = jax.value_and_grad(fitness.call)(param_vector).
    • Four assertions: finite value, grad.shape == (model.total_free_parameters,), all-finite grad, not-all-zero grad. Trailing print("jax_grad/interferometer/lp.py JAX gradient checks passed.").
  3. scripts/jax_grad/interferometer/mge.py — same dataset/mask, model swapped to mge_model_from(mask_radius=3.0, total_gaussians=20, centre_prior_is_uniform=True) single-galaxy. Same body and assertions.
  4. Append two lines to smoke_tests.txt immediately after jax_grad/imaging/mge.py:
    jax_grad/interferometer/lp.py
    jax_grad/interferometer/mge.py
    
  5. No env_vars.yaml change — the jax_grad/ substring override added in task 6 already matches jax_grad/interferometer/.
  6. Verify locally: source activate.sh, run each new script directly, then run both through run_one() in .github/scripts/run_smoke.py to confirm CI behavior.
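One plausible reason for perturbing ell_comps before differentiating, offered as an assumption since the issue does not say why: if the ell_comps priors are centred on zero, the prior-median vector sits at exactly zero ellipticity, where a magnitude like sqrt(e1**2 + e2**2) has an undefined derivative and JAX returns NaN, which would trip the all-finite assertion. The toy function below is a hypothetical stand-in, not autogalaxy's actual ell_comps conversion.

```python
import jax
import jax.numpy as jnp


def toy_ellipticity_magnitude(ell_comps):
    # Hypothetical stand-in for an ell_comps -> (axis ratio, angle) conversion
    # that depends on the magnitude sqrt(e1^2 + e2^2).
    return jnp.sqrt(ell_comps[0] ** 2 + ell_comps[1] ** 2)


grad_fn = jax.grad(toy_ellipticity_magnitude)

# At (0, 0): d sqrt(u)/du = 0.5 / sqrt(0) = inf, and du/de = 2 * e = 0,
# so the chain rule evaluates inf * 0 = NaN for both components.
grad_at_zero = grad_fn(jnp.array([0.0, 0.0]))

# A small perturbation moves the point off the singularity: grad = e / |e|.
grad_perturbed = grad_fn(jnp.array([0.05, 0.05]))

print(grad_at_zero, grad_perturbed)
```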

Key Files

  • autogalaxy_workspace_test/scripts/jax_grad/interferometer/__init__.py — new (empty namespace marker).
  • autogalaxy_workspace_test/scripts/jax_grad/interferometer/lp.py — new (Sersic single galaxy, TransformerDFT, jax.value_and_grad).
  • autogalaxy_workspace_test/scripts/jax_grad/interferometer/mge.py — new (MGE single galaxy, same setup, jax.value_and_grad).
  • autogalaxy_workspace_test/smoke_tests.txt — two new lines after jax_grad/imaging/mge.py.
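The smoke_tests.txt change is small enough to make by hand; the sketch below only illustrates the intended placement, assuming the file holds one relative script path per line. The helper names are hypothetical.

```python
from pathlib import Path

ANCHOR = "jax_grad/imaging/mge.py"
NEW_ENTRIES = ["jax_grad/interferometer/lp.py", "jax_grad/interferometer/mge.py"]


def insert_after_anchor(lines, anchor, new_entries):
    """Return a new list with new_entries placed directly after anchor."""
    if anchor not in lines:
        raise ValueError(f"anchor {anchor!r} not found in smoke test list")
    # Remove existing copies first so re-running never duplicates entries.
    kept = [line for line in lines if line not in new_entries]
    idx = kept.index(anchor) + 1
    return kept[:idx] + list(new_entries) + kept[idx:]


def update_smoke_tests(path="smoke_tests.txt"):
    """Rewrite the file in place (assumes one script path per line)."""
    p = Path(path)
    lines = p.read_text().splitlines()
    p.write_text("\n".join(insert_after_anchor(lines, ANCHOR, NEW_ENTRIES)) + "\n")


example = ["jax_grad/imaging/lp.py", "jax_grad/imaging/mge.py", "other/script.py"]
print(insert_after_anchor(example, ANCHOR, NEW_ENTRIES))
```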

Risks & Notes

  • autolens layout divergence (continued): tasks 6/7/8 all establish the subfolder convention on autogalaxy. The retrofit question was already raised in PR feat: scripts/jax_grad/imaging/ — lp.py + mge.py port (task 6/9) #29 (task 6) — no need to re-litigate.
  • lp.Sersic vs lp_linear.Sersic: task 6 used lp_linear.Sersic; this task uses plain lp.Sersic to match the validated jax_likelihood_functions/interferometer/lp.py pattern. Linear light profiles go through the linear-inversion path, which under interferometer + JAX may exercise different code, so this task sticks with the proven pattern.
  • DFT vs NUFFT transformer: using TransformerDFT (matching jax_likelihood_functions/interferometer/lp.py). NUFFT path is out of scope for the minimum-coverage deliverable.
  • Pytree prerequisite: _register_fit_interferometer_pytrees is in place at autogalaxy/interferometer/model/analysis.py:147 (shipped via task 4).
  • Existing jax_grad/ env override from task 6 already unsets PYAUTO_SMALL_DATASETS and PYAUTO_DISABLE_JAX for any path containing jax_grad/.
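For context on the DFT choice, here is a self-contained toy of a direct-DFT visibility likelihood differentiated with jax.value_and_grad. Model visibilities are a direct sum V_k = sum_j I_j exp(-2 pi i (u_k x_j + v_k y_j)) over image pixels, so the chi-squared is a smooth function of the source parameters. The grid, uv points, and Gaussian source are made up for illustration; this is not ag.TransformerDFT's implementation, and its N_pix * N_vis cost is the usual reason a NUFFT is preferred for large datasets.

```python
import jax
import jax.numpy as jnp

# Toy 16x16 image grid with 0.1" pixels (made-up sizes, kept small).
n = 16
coords = (jnp.arange(n) - n / 2 + 0.5) * 0.1
yy, xx = jnp.meshgrid(coords, coords, indexing="ij")
y_flat, x_flat = yy.ravel(), xx.ravel()  # each (n_pix,)

# Made-up uv coverage (inverse arcsec); real data would supply these.
uv = 2.0 * jax.random.normal(jax.random.PRNGKey(0), (32, 2))
noise_sigma = 0.1


def image_from(params):
    # Circular Gaussian source with params = (intensity, sigma).
    intensity, sigma = params[0], params[1]
    return intensity * jnp.exp(-0.5 * (x_flat**2 + y_flat**2) / sigma**2)


def dft(image_flat):
    # Direct DFT: V_k = sum_j I_j exp(-2 pi i (u_k x_j + v_k y_j)).
    phase = -2.0 * jnp.pi * (uv[:, :1] * x_flat + uv[:, 1:] * y_flat)  # (n_vis, n_pix)
    return jnp.exp(1j * phase) @ image_flat


data = dft(image_from(jnp.array([1.0, 0.3])))  # noise-free toy "observation"


def log_likelihood(params):
    residual = dft(image_from(params)) - data
    # Real and imaginary parts are squared explicitly so chi2 stays smooth.
    chi2 = jnp.sum(jnp.real(residual) ** 2 + jnp.imag(residual) ** 2) / noise_sigma**2
    return -0.5 * chi2


value, grad = jax.value_and_grad(log_likelihood)(jnp.array([0.8, 0.25]))
print(value, grad)
```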

Original Prompt


Create scripts/jax_grad/interferometer/ in @autogalaxy_workspace_test exercising jax.grad on
the autogalaxy interferometer likelihood path.

Layout note

Same subfolder-vs-flat divergence from autolens as task 6 — autogalaxy uses
jax_grad/interferometer/. Flag in PR body; don't retrofit autolens without user go-ahead.

Scripts

autolens has no interferometer jax_grad scripts today. This task creates them from scratch
for autogalaxy, using the corresponding jax_likelihood_functions/interferometer/ scripts as
templates for model + dataset setup, then wrapping the likelihood in jax.grad.

Minimum coverage:

  • jax_grad/interferometer/lp.py
  • jax_grad/interferometer/mge.py

Additional variants only if the jax_likelihood_functions/interferometer/ task surfaced a
ready-to-grad path.

Skip: *_dspl.py.

Pytree prerequisite

Task 4 must have landed (AnalysisInterferometer pytree registration on PyAutoGalaxy).

jax.grad contract

Same as task 6 — finite gradient, correct free-parameter shape.
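The contract can be sketched as a small reusable check. The helper name check_grad_contract is hypothetical, and a toy quadratic log-likelihood stands in for fitness.call on the AnalysisInterferometer path; in the real scripts the parameter vector would come from model.physical_values_from_prior_medians and the expected length from model.total_free_parameters.

```python
import jax
import jax.numpy as jnp


def check_grad_contract(log_likelihood, param_vector, total_free_parameters):
    """Run the four jax_grad assertions: finite value, shape, finite, non-zero."""
    value, grad = jax.value_and_grad(log_likelihood)(param_vector)
    assert jnp.isfinite(value), f"non-finite likelihood: {value}"
    assert grad.shape == (total_free_parameters,), f"bad gradient shape: {grad.shape}"
    assert jnp.all(jnp.isfinite(grad)), f"non-finite gradient: {grad}"
    assert jnp.any(grad != 0.0), "gradient is all zero"
    return value, grad


target = jnp.array([1.0, 2.0, 3.0])


def toy_log_likelihood(params):
    # Stand-in for fitness.call: a quadratic log-likelihood peaked at target.
    return -0.5 * jnp.sum((params - target) ** 2)


value, grad = check_grad_contract(toy_log_likelihood, jnp.zeros(3), total_free_parameters=3)
print(value, grad)
```

Evaluating away from the peak matters for the not-all-zero check: at params == target the toy gradient is exactly zero and the last assertion would fail, which is the kind of degenerate point the real scripts avoid by perturbing the prior-median vector.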

Deliverables

  1. autogalaxy_workspace_test/scripts/jax_grad/interferometer/__init__.py
  2. Scripts above.
  3. Appended to smoke_tests.txt.

Depends on

Task 4 (pytree registration on PyAutoGalaxy interferometer analysis).

Umbrella issue

Task 7/9. Track under the epic issue on PyAutoLabs/autogalaxy_workspace_test.

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
