Overview
The jax_profiling/imaging/ suite has mature JIT profiling scripts for MGE,
pixelization, and Delaunay sources. The interferometer path — which uses a
different kernel (Fourier transform of the mapping matrix, visibilities-space
NNLS) — has no equivalent coverage. This task adds profiling scripts for the
MGE and pixelization source models on interferometer data, building on the
PyAutoFit#1222 / PyAutoArray#279 / PyAutoArray#282 pytree-and-NNLS
infrastructure already on main. Delaunay is deferred to a follow-up.
Plan
- Create jax_profiling/interferometer/ mirroring the imaging suite layout.
- Add simulators/interferometer.py so datasets auto-generate via
  al.util.dataset.should_simulate + subprocess.run, keyed on instrument
  presets (ALMA / SMA / …) at the top of each script.
- Phase 1 (mge.py): MGE source over Isothermal + ExternalShear lens.
  Per-step JIT profiling, full-pipeline JIT, jax.vmap batched throughput,
  eager-vs-JIT numerical agreement, results JSON + PNG.
- Phase 2 (pixelization.py): RectangularAdaptDensity source, same
  profiling shape. Lands after Phase 1 is merged / validated.
- Delaunay is out of scope for this issue — will be a follow-up once the
interferometer pytree path is proven on MGE and pixelization.
- If the interferometer path fails to thread xp=jnp / pytrees through a
  step (NUFFT, dirty-image FFT, visibilities transform), file a separate
  library-level issue rather than working around it in the script.
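The auto-simulation step in the plan can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the workspace's implementation: `dataset_missing` is a hypothetical stand-in for al.util.dataset.should_simulate (whose real signature is not reproduced here), the preset values are placeholders, and checking for data.fits is just one plausible way to decide that a dataset is absent.

```python
import subprocess
import sys
from pathlib import Path

# Placeholder presets for illustration; real values belong in the simulator.
INSTRUMENTS = {
    "sma": {"pixel_scale": 0.1},
    "alma": {"pixel_scale": 0.05},
}


def dataset_path_for(instrument: str, root: Path = Path("dataset")) -> Path:
    """Directory the simulator is expected to write into."""
    return root / "interferometer" / instrument


def dataset_missing(dataset_path: Path) -> bool:
    """Hypothetical stand-in for al.util.dataset.should_simulate."""
    return not (dataset_path / "data.fits").exists()


def simulator_command(instrument: str) -> list:
    """Command line that re-runs the simulator for one instrument."""
    return [
        sys.executable,
        "simulators/interferometer.py",
        "--instrument",
        instrument,
    ]


def ensure_dataset(instrument, root=Path("dataset"), run=subprocess.run):
    """Simulate the dataset only if it is not already on disk."""
    if instrument not in INSTRUMENTS:
        raise KeyError(f"unknown instrument preset: {instrument!r}")
    path = dataset_path_for(instrument, root)
    if dataset_missing(path):
        # check=True makes a failed simulation stop the profiling script.
        run(simulator_command(instrument), check=True)
    return path
```

A profiling script would call `ensure_dataset(instrument)` once near the top, before loading the dataset. The `run` argument is injectable only so the logic can be exercised without launching a subprocess.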
Detailed implementation plan
Affected Repositories
- autolens_workspace_developer (primary)
Work Classification
Workspace
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| autolens_workspace_developer | main | dirty (leftover from smoke-test-optimization; will work inside a worktree so it is preserved) |
Suggested branch: feature/interferometer-jax-profiling
Worktree root: ~/Code/PyAutoLabs-wt/interferometer-jax-profiling/ (created later by /start_workspace)
Phase 1 — simulators/interferometer.py + mge.py
- Simulator: jax_profiling/interferometer/simulators/interferometer.py
  - Instrument presets dict mirroring simulators/imaging.py (sma, alma,
    etc.), each with pixel_scale, uv_wavelengths fits path, exposure_time,
    noise_sigma, transformer_class.
  - simulate(instrument) writes data.fits, noise_map.fits,
    uv_wavelengths.fits, positions.json under
    dataset/interferometer/<instrument>/. The pattern is copied from
    autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/simulator.py
    (real/imag stacked into FITS, PointSolver for positions).
  - CLI entry: python simulators/interferometer.py --instrument sma.
- jax_profiling/interferometer/mge.py, following the full pattern from
  jax_profiling/imaging/mge.py:
  - INSTRUMENTS dict at top, instrument = "sma" default.
  - Auto-simulation via should_simulate + subprocess.run of the simulator.
  - Eager baseline: build FitInterferometer, print figure_of_merit /
    log_likelihood.
  - Per-step JIT profiling: ray-trace grids → mapping matrix →
    transformed mapping matrix (visibilities) → data vector D → curvature
    F (with Jacobi preconditioning, nnls_target_kappa=1e-2) → NNLS
    reconstruction → mapped visibilities → chi-squared / log-likelihood.
  - Full-pipeline JIT:
    jax.jit(AnalysisInterferometer(dataset, use_jax=True).log_likelihood_function)
    on a pytree-registered ModelInstance (via autofit.jax.register_model).
  - jax.vmap batched evaluation, per-call speedup print.
  - Numerical assertion eager vs JIT at rtol=1e-4.
  - Results: results/mge_likelihood_summary_<instrument>_v<al_version>.json
    plus a .png bar chart, schema matching the imaging counterpart.
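The timing and agreement checks listed above could take roughly the following shape. This is a sketch only: `log_likelihood` is a toy Gaussian likelihood standing in for the real FitInterferometer pipeline, and `time_call`, the array sizes and the batch size of 8 are illustrative choices rather than anything taken from the imaging scripts. The one thing that matters for correct timings is blocking on the result, since JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood(params, data, noise):
    """Toy Gaussian log-likelihood (stand-in for the interferometer fit)."""
    model = params[0] * jnp.sin(params[1] * jnp.arange(data.shape[0]))
    chi_squared = jnp.sum(((data - model) / noise) ** 2)
    return -0.5 * chi_squared


def time_call(fn, *args, repeats=10):
    """Mean wall-clock seconds per call, blocking on each async result."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


data = jnp.linspace(0.0, 1.0, 256)
noise = jnp.full(256, 0.1)
params = jnp.array([0.5, 0.02])

# Eager baseline.
eager_value = log_likelihood(params, data, noise)

# The first jitted call includes tracing and compilation.
jitted = jax.jit(log_likelihood)
start = time.perf_counter()
jit_value = jitted(params, data, noise).block_until_ready()
compile_seconds = time.perf_counter() - start
steady_seconds = time_call(jitted, params, data, noise)

# Batched evaluation: map over parameter vectors, share data and noise.
batch = jnp.stack([params * (1.0 + 0.01 * i) for i in range(8)])
batched = jax.jit(jax.vmap(log_likelihood, in_axes=(0, None, None)))
batch_values = batched(batch, data, noise).block_until_ready()
per_likelihood_seconds = time_call(batched, batch, data, noise) / batch.shape[0]

# Eager and JIT paths must agree within the tolerance named in the plan.
np.testing.assert_allclose(
    np.asarray(jit_value), np.asarray(eager_value), rtol=1e-4
)

print(f"first call (compile + run): {compile_seconds:.4f} s")
print(f"steady state per call:      {steady_seconds:.6f} s")
print(f"vmap per likelihood:        {per_likelihood_seconds:.6f} s")
```

Reporting the first call separately from the steady-state mean keeps compile cost from skewing the per-evaluation number, which is the comparison the results JSON is meant to track.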
Phase 2 — pixelization.py
Lands after Phase 1 merges. Same structure as imaging pixelization.py,
adapted for FitInterferometer + RectangularAdaptDensity source. Reuses
the Phase 1 simulator.
Key Files
jax_profiling/interferometer/simulators/interferometer.py — new
jax_profiling/interferometer/mge.py — new (Phase 1)
jax_profiling/interferometer/pixelization.py — new (Phase 2)
jax_profiling/interferometer/results/ — new (artefacts committed per imaging precedent)
Expected Blockers (per prompt)
- Interferometer path may not thread xp=jnp through every step (visibilities
  transform, dirty-image FFT): file a separate issue if hit.
- jax.jit may balloon compile time for large visibilities: document and
  suggest a batch_size knob.
- NUFFT transformer pytree-compatibility: flag as a library-level blocker.
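One way to find which step fails to trace, before filing the separate issue, is to probe each step in isolation with abstract evaluation. The sketch below is an assumption about how such a probe might look: `traces_under_jit` is a hypothetical helper, and the two FFT functions are toy stand-ins, not the library's visibilities transform or dirty-image FFT.

```python
import jax
import jax.numpy as jnp
import numpy as np


def traces_under_jit(fn, *example_args):
    """Return (ok, message) for abstractly tracing fn with JAX.

    jax.eval_shape traces the function without running any computation,
    so a failure here indicates the step is not traceable as written.
    """
    try:
        jax.eval_shape(fn, *example_args)
    except Exception as exc:  # broad on purpose: the goal is to report it
        return False, f"{type(exc).__name__}: {exc}".splitlines()[0]
    return True, ""


def fft_with_jnp(image):
    """Toy step written against jax.numpy, so it traces cleanly."""
    return jnp.fft.fft2(image)


def fft_with_numpy(image):
    """Toy step that forces a concrete NumPy array, which a tracer lacks."""
    return np.fft.fft2(image)


example = jnp.ones((4, 4))
for step in (fft_with_jnp, fft_with_numpy):
    ok, message = traces_under_jit(step, example)
    status = "traces" if ok else f"fails ({message})"
    print(f"{step.__name__}: {status}")
```

The error type and first line of the message are usually enough to paste into a library-level issue without reproducing the whole profiling script.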
Original Prompt
JAX JIT Profiling: Interferometer (MGE, Pixelization, Delaunay)
Context
autolens_workspace_developer/jax_profiling/imaging/ contains mature JAX
JIT profiling scripts for the three main source models:
mge.py — Multi-Gaussian Expansion source
pixelization.py — RectangularAdaptDensity pixelization source
delaunay.py — Delaunay pixelization source
Each script builds the log-likelihood function, applies jax.jit,
measures compile-time + per-eval runtime, and benchmarks against the
eager path. These scripts have driven many of the recent JAX
performance wins (e.g. the 80–96% runtime reductions in
smoke-test-optimization).
There is no equivalent coverage for interferometer datasets.
The FitInterferometer pipeline exercises a different kernel
(Fourier transform of the mapping matrix, visibilities-space
NNLS) whose JIT behaviour is not yet characterised.
Pytree infrastructure (already shipped — build on top of it)
These scripts target the full pytree approach:
jax.jit(AnalysisInterferometer.log_likelihood) on a real model with
all priors flowing as pytree leaves. Three pieces of library
infrastructure that make this viable have already landed on main:
- PyAutoFit#1222: TuplePrior is now registered as a JAX pytree.
  This is the fix that raised the live JAX-leaf count on a typical
  Isothermal+Shear+MGE model from 3 (only free floats) to 167 (every
  prior in the model), so jax.value_and_grad actually flows through
  the whole model rather than freezing most of it.
- PyAutoArray#279: Jacobi preconditioning on the NNLS curvature
  matrix. Mitigates ill-conditioning before the relaxed-KKT backward
  pass runs.
- PyAutoArray#282: nnls_target_kappa=1.0e-2 config default
  (was inheriting jaxnnls's 1e-3), which was producing NaN in the
  NNLS backward pass on real MGE pipelines even with Jacobi
  preconditioning.
You should not need to modify any library code to get these scripts
running — they depend on the above being present, and they provide
the signal if any of it regresses.
Task
Create three profiling scripts in
autolens_workspace_developer/jax_profiling/interferometer/:
mge.py — MGE source, Isothermal + ExternalShear lens.
pixelization.py — RectangularAdaptDensity source.
delaunay.py — Delaunay source.
Structure each script like its imaging counterpart:
- Dataset auto-simulation via al.util.dataset.should_simulate +
  subprocess.run on a matching simulators/interferometer.py.
  Instrument presets (ALMA / SMA / …) keyed off an instrument
  variable at the top, same pattern as imaging/mge.py.
- Eager baseline: build FitInterferometer, print
  figure_of_merit / log_likelihood.
- JAX path: wrap Fitness.call in jax.jit, measure first-call
  (compile) and steady-state runtimes over N repeats.
- vmap path: batch batch_size parameter vectors through
  fitness._vmap, measure per-likelihood cost.
- Assertion: numerical agreement between eager and JIT paths within
  a sensible rtol.
- Output: write a results JSON + PNG summary into
  jax_profiling/interferometer/results/ mirroring the imaging
  results schema so they can be compared.
Dependencies
- An interferometer dataset simulator must exist alongside the
  imaging one. Check autolens_workspace_developer/jax_profiling/
  for whether a simulator is already in place; if not, copy the
  pattern from autolens_workspace_test/scripts/interferometer/simulator/.
- al.FitInterferometer and al.AnalysisInterferometer are the
  entry points; the xp=jnp path goes through
  autoarray.inversion.inversion_interferometer.mapper_operator.
Expected output
Three working scripts that run end-to-end via:
cd jax_profiling/interferometer
python mge.py
python pixelization.py
python delaunay.py
Each producing a JIT vs eager timing comparison, a vmap batch
throughput measurement, and a results artefact for tracking.
Likely blockers to raise if encountered
- Interferometer path may not yet thread `xp=jnp` through every
step (visibilities transform, dirty-image FFT, etc.) — if a step
fails under JAX tracing, file a separate issue rather than
hacking around it.
- `jax.jit` may balloon compile time for large `visibilities` —
document and suggest a `batch_size` knob to split compilation.
- If the NUFFT transformer is not pytree-compatible, flag it as a
library-level blocker rather than working around it in the script.