
feat: JAX JIT profiling for interferometer (MGE + pixelization) #14

@Jammy2211

Description


Overview

The jax_profiling/imaging/ suite has mature JIT profiling scripts for MGE,
pixelization, and Delaunay sources. The interferometer path — which uses a
different kernel (Fourier transform of the mapping matrix, visibilities-space
NNLS) — has no equivalent coverage. This task adds profiling scripts for the
MGE and pixelization source models on interferometer data, building on the
PyAutoFit#1222 / PyAutoArray#279 / PyAutoArray#282 pytree-and-NNLS
infrastructure already on main. Delaunay is deferred to a follow-up.

Plan

  • Create jax_profiling/interferometer/ mirroring the imaging suite layout.
  • Add simulators/interferometer.py so datasets auto-generate via
    al.util.dataset.should_simulate + subprocess.run, keyed on instrument
    presets (ALMA / SMA / …) at the top of each script.
  • Phase 1 — mge.py: MGE source over Isothermal + ExternalShear lens.
    Per-step JIT profiling, full-pipeline JIT, jax.vmap batched throughput,
    eager-vs-JIT numerical agreement, results JSON + PNG.
  • Phase 2 — pixelization.py: RectangularAdaptDensity source, same
    profiling shape. Lands after Phase 1 is merged / validated.
  • Delaunay is out of scope for this issue — will be a follow-up once the
    interferometer pytree path is proven on MGE and pixelization.
  • If the interferometer path fails to thread xp=jnp / pytrees through a
    step (NUFFT, dirty-image FFT, visibilities transform), file a separate
    library-level issue rather than working around it in the script.
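The auto-simulation step in the plan can be sketched as below. This is a minimal sketch under stated assumptions. `needs_simulation` is a hypothetical stand-in for al.util.dataset.should_simulate, whose real signature is not given here. The expected file names and the dataset/interferometer/<instrument>/ layout are taken from the Phase 1 simulator description, not from the library.

```python
import subprocess
import sys
from pathlib import Path

# Assumed outputs of simulators/interferometer.py (from the Phase 1 simulator notes).
EXPECTED_FILES = ("data.fits", "noise_map.fits", "uv_wavelengths.fits")


def needs_simulation(dataset_path: Path) -> bool:
    """Hypothetical stand-in for al.util.dataset.should_simulate:
    simulate if any expected file is missing."""
    return not all((dataset_path / name).exists() for name in EXPECTED_FILES)


def ensure_dataset(instrument: str, simulator: Path, root: Path = Path("dataset")) -> Path:
    """Run the simulator for `instrument` only when its dataset is incomplete."""
    dataset_path = root / "interferometer" / instrument
    if needs_simulation(dataset_path):
        # A separate interpreter keeps the simulator's imports and global
        # state out of the profiling process; check=True surfaces failures.
        subprocess.run(
            [sys.executable, str(simulator), "--instrument", instrument],
            check=True,
        )
    return dataset_path
```

With `instrument = "sma"` at the top of mge.py, a call such as `ensure_dataset(instrument, Path("simulators/interferometer.py"))` would simulate on the first run and reuse the files afterwards.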
Detailed implementation plan

Affected Repositories

  • autolens_workspace_developer (primary)

Work Classification

Workspace

Branch Survey

| Repository | Current branch | Dirty? |
|---|---|---|
| autolens_workspace_developer | main | dirty (leftover from smoke-test-optimization; work will happen inside a worktree so it is preserved) |

Suggested branch: feature/interferometer-jax-profiling
Worktree root: ~/Code/PyAutoLabs-wt/interferometer-jax-profiling/ (created later by /start_workspace)

Phase 1 — simulators/interferometer.py + mge.py

  1. Simulator: jax_profiling/interferometer/simulators/interferometer.py

    • Instrument presets dict mirroring simulators/imaging.py (sma, alma,
      etc.), each with pixel_scale, uv_wavelengths fits path, exposure_time,
      noise_sigma, transformer_class.
    • simulate(instrument) writes data.fits, noise_map.fits,
      uv_wavelengths.fits, positions.json under
      dataset/interferometer/<instrument>/ — pattern copied from
      autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/simulator.py
      (real/imag stacked into FITS, PointSolver for positions).
    • CLI entry: python simulators/interferometer.py --instrument sma.
  2. jax_profiling/interferometer/mge.py — full pattern from
    jax_profiling/imaging/mge.py:

    • INSTRUMENTS dict at top, instrument = "sma" default.
    • Auto-simulation via should_simulate + subprocess.run of the
      simulator.
    • Eager baseline: build FitInterferometer, print figure_of_merit /
      log_likelihood.
    • Per-step JIT profiling: ray-trace grids → mapping matrix →
      transformed mapping matrix (visibilities) → data vector D → curvature
      F (with Jacobi preconditioning, nnls_target_kappa=1e-2) → NNLS
      reconstruction → mapped visibilities → chi-squared / log-likelihood.
    • Full-pipeline JIT: jax.jit(AnalysisInterferometer(dataset, use_jax=True).log_likelihood_function) on a pytree-registered
      ModelInstance (via autofit.jax.register_model).
    • jax.vmap batched evaluation, per-call speedup print.
    • Numerical assertion eager vs JIT at rtol=1e-4.
    • Results: results/mge_likelihood_summary_<instrument>_v<al_version>.json
      plus a .png bar chart, with a schema matching the imaging counterpart.
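The per-step chain above can be sketched eagerly on a toy problem. The chain runs from the mapping matrix to the transformed mapping matrix, then D and F with Jacobi preconditioning, then NNLS, mapped visibilities and chi-squared. This is an illustrative NumPy sketch and not the PyAutoArray implementation. Several choices are assumptions: the direct Fourier transform with an exp(-2πi(u·x + v·y)) sign convention, stacking real parts over imaginary parts, one shared noise sigma per component, and SciPy's `nnls` as a stand-in solver. nnls_target_kappa has no analogue here.

```python
import numpy as np
from scipy.optimize import nnls


def transform_mapping_matrix(mapping_matrix, grid, uv):
    """Direct Fourier transform of each mapping-matrix column.

    mapping_matrix: (n_image, n_source); grid: (n_image, 2); uv: (n_vis, 2).
    Returns a real (2 * n_vis, n_source) array: real parts stacked over imaginary parts.
    """
    phase = -2.0 * np.pi * (uv @ grid.T)          # (n_vis, n_image)
    tmm = np.exp(1j * phase) @ mapping_matrix     # (n_vis, n_source)
    return np.vstack([tmm.real, tmm.imag])


def jacobi_nnls(tmm, vis, noise):
    """Noise-weighted NNLS with Jacobi (diagonal) preconditioning.

    Minimising ||A s - b||^2 with A = tmm / noise and b = vis / noise is the same
    problem as minimising s^T F s - 2 D^T s, with curvature F = A^T A and data
    vector D = A^T b. Columns of A are rescaled so that diag(F) == 1.
    """
    a = tmm / noise[:, None]
    b = vis / noise
    scale = 1.0 / np.sqrt(np.sum(a * a, axis=0))  # 1 / sqrt(diag(F))
    s_scaled, _ = nnls(a * scale, b)
    return s_scaled * scale                       # positive scaling keeps s >= 0


def chi_squared(tmm, reconstruction, vis, noise):
    model_vis = tmm @ reconstruction              # mapped visibilities
    return float(np.sum(((vis - model_vis) / noise) ** 2))


def toy_problem(seed=0, noise_sigma=0.01):
    """Hypothetical toy: 8x8 image grid, 16 source pixels of 4 image pixels each, 50 baselines."""
    rng = np.random.default_rng(seed)
    coords = (np.arange(8) - 3.5) * 0.1
    grid = np.stack(np.meshgrid(coords, coords), axis=-1).reshape(-1, 2)
    mapping_matrix = np.zeros((64, 16))
    mapping_matrix[np.arange(64), np.arange(64) // 4] = 1.0
    uv = rng.normal(scale=5.0, size=(50, 2))
    tmm = transform_mapping_matrix(mapping_matrix, grid, uv)
    s_true = rng.uniform(0.5, 1.5, size=16)
    noise = np.full(tmm.shape[0], noise_sigma)
    vis = tmm @ s_true + rng.normal(scale=noise_sigma, size=tmm.shape[0])
    return tmm, vis, noise, s_true


tmm, vis, noise, s_true = toy_problem()
reconstruction = jacobi_nnls(tmm, vis, noise)
print(f"chi-squared {chi_squared(tmm, reconstruction, vis, noise):.1f} over {vis.size} components")
```

Rescaling columns by positive factors leaves the non-negativity constraint unchanged. That is why the preconditioned solution can be mapped back with a single multiply.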

Phase 2 — pixelization.py

Lands after Phase 1 merges. Same structure as imaging pixelization.py,
adapted for FitInterferometer + RectangularAdaptDensity source. Reuses
the Phase 1 simulator.

Key Files

  • jax_profiling/interferometer/simulators/interferometer.py — new
  • jax_profiling/interferometer/mge.py — new (Phase 1)
  • jax_profiling/interferometer/pixelization.py — new (Phase 2)
  • jax_profiling/interferometer/results/ — new (artefacts committed per imaging precedent)

Expected Blockers (per prompt)

  • Interferometer path may not thread xp=jnp through every step (visibilities
    transform, dirty-image FFT) — file separate issue if hit.
  • jax.jit may balloon compile time for large visibilities — document and
    suggest a batch_size knob.
  • NUFFT transformer pytree-compatibility — flag as library-level blocker.
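One possible shape for the suggested batch_size knob is sketched below. It assumes the pressure comes from materialising the full (n_visibilities × n_image) Fourier phase matrix in one go. The visibilities are processed in fixed-size chunks with jax.lax.map, and the last chunk is padded so every chunk has the same static shape. `dft_chunked` is hypothetical and uses a direct Fourier transform rather than the library's NUFFT. Whether it also reduces compile time for the real transformer would need measuring.

```python
import jax
import jax.numpy as jnp


def dft_chunked(image, grid, uv, batch_size):
    """Hypothetical direct Fourier transform of a flattened image, evaluated
    batch_size visibilities at a time so only a (batch_size, n_image) phase
    matrix is live per step."""
    n_vis = uv.shape[0]
    n_chunks = -(-n_vis // batch_size)            # ceiling division
    pad = n_chunks * batch_size - n_vis
    uv_chunks = jnp.pad(uv, ((0, pad), (0, 0))).reshape(n_chunks, batch_size, 2)

    def one_chunk(uv_chunk):
        phase = -2.0 * jnp.pi * (uv_chunk @ grid.T)  # (batch_size, n_image)
        return jnp.exp(1j * phase) @ image           # (batch_size,) complex

    # lax.map runs one_chunk sequentially over the leading axis (via scan),
    # so the loop body is compiled once rather than unrolled per chunk.
    vis = jax.lax.map(one_chunk, uv_chunks).reshape(-1)
    return vis[:n_vis]                               # drop the padded rows


# batch_size fixes array shapes, so it must be static under jit.
dft_chunked_jit = jax.jit(dft_chunked, static_argnames="batch_size")
```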

Original Prompt


JAX JIT Profiling: Interferometer (MGE, Pixelization, Delaunay)

Context

autolens_workspace_developer/jax_profiling/imaging/ contains mature JAX
JIT profiling scripts for the three main source models:

  • mge.py — Multi-Gaussian Expansion source
  • pixelization.py — RectangularAdaptDensity pixelization source
  • delaunay.py — Delaunay pixelization source

Each script builds the log-likelihood function, applies jax.jit,
measures compile-time + per-eval runtime, and benchmarks against the
eager path. These scripts have driven many of the recent JAX
performance wins (e.g. the 80–96% runtime reductions in
smoke-test-optimization).

There is no equivalent coverage for interferometer datasets.
The FitInterferometer pipeline exercises a different kernel
(Fourier transform of the mapping matrix, visibilities-space
NNLS) whose JIT behaviour is not yet characterised.
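The kernel can be written out in one notation. Let M be the mapping matrix, (x_n, y_n) the image-plane pixel centres, (u_k, v_k) the baselines and V_k the visibilities with noise σ_k. This assumes a direct Fourier transform with an e^{-2πi(·)} sign convention and the same σ_k on the real and imaginary parts. The NUFFT path and the library's exact conventions may differ.

```math
\begin{aligned}
\tilde{M}_{kj} &= \sum_n M_{nj}\, e^{-2\pi \mathrm{i}\,(u_k x_n + v_k y_n)} \\
D_j &= \sum_k \frac{\mathrm{Re}\,\tilde{M}_{kj}\,\mathrm{Re}\,V_k + \mathrm{Im}\,\tilde{M}_{kj}\,\mathrm{Im}\,V_k}{\sigma_k^2} \\
F_{jl} &= \sum_k \frac{\mathrm{Re}\,\tilde{M}_{kj}\,\mathrm{Re}\,\tilde{M}_{kl} + \mathrm{Im}\,\tilde{M}_{kj}\,\mathrm{Im}\,\tilde{M}_{kl}}{\sigma_k^2} \\
\hat{s} &= P\,\hat{s}', \qquad \hat{s}' = \arg\min_{s' \ge 0} \left( \tfrac{1}{2}\, s'^{\top} P F P\, s' - (P D)^{\top} s' \right), \qquad P = \mathrm{diag}(F)^{-1/2} \\
\chi^2 &= \sum_k \frac{\bigl| V_k - \sum_j \tilde{M}_{kj}\, \hat{s}_j \bigr|^2}{\sigma_k^2}
\end{aligned}
```

The substitution s = P s' turns the original objective ½ sᵀFs − Dᵀs into the preconditioned one, and s ≥ 0 holds exactly when s' ≥ 0 because P is positive and diagonal.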

Pytree infrastructure (already shipped — build on top of it)

These scripts target the full pytree approach:
jax.jit(AnalysisInterferometer.log_likelihood) on a real model with
all priors flowing as pytree leaves. Three pieces of library
infrastructure that make this viable have already landed on main:

  • PyAutoFit#1222: TuplePrior is now registered as a JAX pytree.
    This is the fix that raised the live JAX-leaf count on a typical
    Isothermal+Shear+MGE model from 3 (only free floats) to 167 (every
    prior in the model), so jax.value_and_grad actually flows through
    the whole model rather than freezing most of it.
  • PyAutoArray#279: Jacobi preconditioning on the NNLS curvature
    matrix, which mitigates ill-conditioning before the relaxed-KKT
    backward pass runs.
  • PyAutoArray#282: nnls_target_kappa=1.0e-2 config default. It
    previously inherited jaxnnls's 1e-3, which was producing NaN in the
    NNLS backward pass on real MGE pipelines even with Jacobi
    preconditioning.

You should not need to modify any library code to get these scripts
running — they depend on the above being present, and they provide
the signal if any of it regresses.
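The effect of registration on the leaf count can be illustrated with toy classes. `UnregisteredPair` and `RegisteredPair` are hypothetical stand-ins, not PyAutoFit's TuplePrior. The sketch only shows that JAX treats an unregistered object as one opaque leaf. A class registered with jax.tree_util.register_pytree_node_class exposes its fields as separate leaves that jax.grad can differentiate.

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_leaves


class UnregisteredPair:
    """Plain object: JAX treats the whole instance as one opaque leaf."""

    def __init__(self, a, b):
        self.a, self.b = a, b


@register_pytree_node_class
class RegisteredPair:
    """Hypothetical stand-in for a tuple-valued prior, e.g. a (y, x) centre."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def tree_flatten(self):
        return (self.a, self.b), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def toy_log_likelihood(pair):
    return -((pair.a - 1.0) ** 2 + (pair.b + 2.0) ** 2)


print(len(tree_leaves({"centre": UnregisteredPair(0.0, 0.0)})))  # 1 opaque leaf
print(len(tree_leaves({"centre": RegisteredPair(0.0, 0.0)})))    # 2 float leaves

# Gradients come back in the same RegisteredPair structure, one per field.
grads = jax.grad(toy_log_likelihood)(RegisteredPair(0.0, 0.0))
print(float(grads.a), float(grads.b))  # 2.0 -4.0
```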

Task

Create three profiling scripts in
autolens_workspace_developer/jax_profiling/interferometer/:

  1. mge.py — MGE source, Isothermal + ExternalShear lens.
  2. pixelization.py — RectangularAdaptDensity source.
  3. delaunay.py — Delaunay source.

Structure each script like its imaging counterpart:

  • Dataset auto-simulation via al.util.dataset.should_simulate +
    subprocess.run on a matching simulators/interferometer.py.
    Instrument presets (ALMA / SMA / …) keyed off an instrument
    variable at the top, same pattern as imaging/mge.py.
  • Eager baseline: build FitInterferometer, print
    figure_of_merit / log_likelihood.
  • JAX path: wrap Fitness.call in jax.jit, measure first-call
    (compile) and steady-state runtimes over N repeats.
  • vmap path: batch batch_size parameter vectors through
    fitness._vmap, measure per-likelihood cost.
  • Assertion: numerical agreement between eager and JIT paths within
    a sensible rtol.
  • Output: write a results JSON + PNG summary into
    jax_profiling/interferometer/results/ mirroring the imaging
    results schema so they can be compared.
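The JAX path, the vmap path and the assertion above share one timing shape. This is a minimal sketch with two substitutions: a toy log-likelihood stands in for Fitness.call, and a plain jax.vmap stands in for fitness._vmap. Both stand-ins are hypothetical. block_until_ready() is needed because JAX dispatches work asynchronously, so without it the timer would mostly measure dispatch.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def toy_log_likelihood(params):
    """Hypothetical stand-in for Fitness.call: Gaussian log-likelihood of a 2-parameter model."""
    x = jnp.linspace(0.0, 10.0, 1000)
    model = params[0] * jnp.sin(x * params[1])
    return -0.5 * jnp.sum((jnp.sin(x) - model) ** 2)


def profile(fn, args, n_repeats=10):
    """Return (first_call_s, mean_steady_state_s, last_result)."""
    start = time.perf_counter()
    result = fn(args).block_until_ready()  # first call includes tracing + compilation under jit
    first_call = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n_repeats):
        result = fn(args).block_until_ready()
    return first_call, (time.perf_counter() - start) / n_repeats, result


params = jnp.array([0.9, 1.1])
compile_s, jit_s, jit_value = profile(jax.jit(toy_log_likelihood), params)
_, eager_s, eager_value = profile(toy_log_likelihood, params)  # op-by-op baseline

# Eager and JIT must agree numerically (rtol as in the plan).
np.testing.assert_allclose(np.asarray(jit_value), np.asarray(eager_value), rtol=1e-4)

# vmap path: one compiled call evaluates a whole batch of parameter vectors.
batch_size = 32
batch = params + 0.01 * jax.random.normal(jax.random.PRNGKey(0), (batch_size, 2))
_, batch_s, batch_values = profile(jax.jit(jax.vmap(toy_log_likelihood)), batch)

print(f"first call {compile_s:.3f}s | JIT {jit_s * 1e3:.3f}ms | eager {eager_s * 1e3:.3f}ms")
print(f"vmap per-likelihood {batch_s / batch_size * 1e3:.4f}ms")
```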

Dependencies

  • An interferometer dataset simulator must exist alongside the
    imaging one. Check autolens_workspace_developer/jax_profiling/
    for whether a simulator is already in place; if not, copy the
    pattern from autolens_workspace_test/scripts/interferometer/simulator/.
  • al.FitInterferometer and al.AnalysisInterferometer are the
    entry points; the xp=jnp path goes through
    autoarray.inversion.inversion_interferometer.mapper_operator.
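Threading xp=jnp means each step takes an array-namespace argument instead of calling NumPy directly. The same function then runs eagerly with numpy and traces with jax.numpy. The sketch below uses hypothetical step functions, not autoarray's mapper_operator code. It includes the kind of stray NumPy call that breaks tracing, which JAX reports as jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def visibilities_from_image(image, phase_matrix, xp=np):
    """Hypothetical step: real/imag-stacked visibilities from a precomputed complex phase matrix."""
    vis = xp.asarray(phase_matrix) @ image
    return xp.concatenate([xp.real(vis), xp.imag(vis)])


def leaky_visibilities_from_image(image, phase_matrix, xp=np):
    """Same step, but a hard-coded np.asarray forces a concrete array and fails under jit."""
    vis = np.asarray(xp.asarray(phase_matrix) @ image)
    return xp.concatenate([xp.real(vis), xp.imag(vis)])


rng = np.random.default_rng(0)
phase_matrix = np.exp(-2j * np.pi * rng.random((6, 4)))
image = rng.random(4)

eager = visibilities_from_image(image, phase_matrix)  # NumPy path, float64
traced = jax.jit(lambda img: visibilities_from_image(img, phase_matrix, xp=jnp))(image)
np.testing.assert_allclose(np.asarray(traced), eager, rtol=1e-5, atol=1e-5)

leak_detected = False
try:
    jax.jit(lambda img: leaky_visibilities_from_image(img, phase_matrix, xp=jnp))(image)
except jax.errors.TracerArrayConversionError:
    leak_detected = True  # this step does not thread xp through
print("leaky step detected:", leak_detected)
```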

Expected output

Three working scripts that run end-to-end via:

cd jax_profiling/interferometer
python mge.py
python pixelization.py
python delaunay.py

Each script should produce a JIT-vs-eager timing comparison, a vmap batch
throughput measurement, and a results artefact for tracking.
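The results artefact named in the plan can be written as sketched below: results/mge_likelihood_summary_<instrument>_v<al_version>.json plus a PNG bar chart. The JSON keys are hypothetical placeholders, and the real schema should be copied from the imaging counterpart so the two can be compared. `al_version` is assumed to be autolens's version string. matplotlib's Agg backend is used so the script runs headless.

```python
import json
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # headless backend: no display required
import matplotlib.pyplot as plt


def write_summary(results_dir, model, instrument, al_version, timings):
    """Write <model>_likelihood_summary_<instrument>_v<al_version>.{json,png}.

    timings maps a label (e.g. "eager", "jit", "vmap_per_call") to seconds.
    The JSON keys here are placeholders for the imaging schema.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{model}_likelihood_summary_{instrument}_v{al_version}"

    summary = {"model": model, "instrument": instrument,
               "al_version": al_version, "timings_s": timings}
    json_path = results_dir / f"{stem}.json"
    json_path.write_text(json.dumps(summary, indent=2))

    fig, ax = plt.subplots(figsize=(6, 3))
    ax.bar(list(timings), list(timings.values()))
    ax.set_yscale("log")  # keeps very different timings readable on one axis
    ax.set_ylabel("seconds per likelihood")
    ax.set_title(stem)
    png_path = results_dir / f"{stem}.png"
    fig.savefig(png_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    return json_path, png_path
```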

Likely blockers to raise if encountered

  • Interferometer path may not yet thread `xp=jnp` through every
    step (visibilities transform, dirty-image FFT, etc.) — if a step
    fails under JAX tracing, file a separate issue rather than
    hacking around it.
  • `jax.jit` may balloon compile time for large `visibilities` —
    document and suggest a `batch_size` knob to split compilation.
  • If the NUFFT transformer is not pytree-compatible, flag it as a
    library-level blocker rather than working around it in the script.
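A quick way to probe the NUFFT-transformer blocker is to pass the object straight into a jitted function. Plain Python objects are rejected as jit arguments with a TypeError, while registered pytrees are accepted. The transformer classes and the `passes_through_jit` helper below are hypothetical illustrations, not the library's transformer.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class OpaqueTransformer:
    """Hypothetical transformer holding precomputed arrays; not a registered pytree."""

    def __init__(self, uv):
        self.uv = uv


@register_pytree_node_class
class PytreeTransformer(OpaqueTransformer):
    """Same data, registered so JAX can flatten it into array leaves."""

    def tree_flatten(self):
        return (self.uv,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def passes_through_jit(obj):
    """Return True if obj is accepted as a jax.jit argument."""
    try:
        jax.jit(lambda o: o)(obj)
    except TypeError:  # raised for arguments that are neither arrays nor pytrees
        return False
    return True


print(passes_through_jit(OpaqueTransformer(jnp.ones(3))))  # False
print(passes_through_jit(PytreeTransformer(jnp.ones(3))))  # True
```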
