Skip to content

test: TransformerNUFFT cross-check (Path B) for interferometer JAX likelihood scripts#89

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/jax-likelihood-nufft-crosscheck
May 10, 2026
Merged

test: TransformerNUFFT cross-check (Path B) for interferometer JAX likelihood scripts#89
Jammy2211 merged 1 commit into
mainfrom
feature/jax-likelihood-nufft-crosscheck

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds a TransformerNUFFT cross-check (Path B) to each interferometer JAX likelihood script. Each script keeps its existing TransformerDFT path (Path A) untouched, then re-runs the same vmap likelihood with transformer_class=al.TransformerNUFFT and asserts the same hardcoded literal at rtol=1e-4. This proves end-to-end that the slow direct-DFT path and the fast nufftax-backed NUFFT path produce the same likelihood, catching any future drift between the two transformers.

Companion library PR

PyAutoArray PR #305perf: batched transform_mapping_matrix. The DSPL Path B in this workspace PR depends on that library fix to compile in seconds instead of minutes. The library PR should land first.

Per-script changes

  • lp.py / mge.py / mge_group.py / delaunay.py / rectangular.py / rectangular_dspl.py — standard Path B append, asserts the literal at rtol=1e-4.
  • delaunay_mge.py — rtol loosened to 2e-3. The Delaunay+MGE inversion amplifies the ~1e-13 forward-operator difference into a ~5e-4 relative shift in the final log-likelihood (likely mesh-vertex selection sensitivity).
  • rectangular_mge.py / rectangular_dspl.py — include gc.collect(), jax.clear_caches(), and parameters_nufft = parameters[:1] before Path B to keep peak memory within the 16 GB CI box.
  • rectangular_sparse.pythree-way cross-check: DFT+sparse_operator (existing literal) → DFT-no-sparse → TransformerNUFFT (no sparse). The latter two assert against rectangular.py's canonical -3164.286252 literal because apply_sparse_operator gives a numerically distinct ~0.4% off result from the bare DFT, hence the path-specific literals.

DSPL mesh reduction

Both interferometer/rectangular_dspl.py and imaging/rectangular_dspl.py reduce mesh_shape from (30, 30) to (8, 8) to match the resolution the other rectangular*.py tests use. Test scripts should be lightweight; (30, 30) produced 1800-pixel JIT traces and slow compiles. New canonical literals captured empirically:

  • interferometer/rectangular_dspl.py: -3170.19672623
  • imaging/rectangular_dspl.py: -3797.73182794

Scripts Changed

  • scripts/jax_likelihood_functions/interferometer/{lp,mge,mge_group,delaunay,delaunay_mge,rectangular,rectangular_mge,rectangular_dspl,rectangular_sparse}.py
  • scripts/jax_likelihood_functions/imaging/rectangular_dspl.py

Test plan

  • Every interferometer script: python <script>.py ends with PASS: TransformerNUFFT cross-check matches TransformerDFT.
  • imaging/rectangular_dspl.py: passes the existing assertion at the new literal.
  • DSPL Path B compiles in ~30s (down from 11+ min) with the library fix.

🤖 Generated with Claude Code

…kelihood scripts

Each interferometer JAX likelihood script now appends a Path B that re-runs
the same vmap likelihood with `transformer_class=TransformerNUFFT` (the new
nufftax-backed default) and asserts the same hardcoded literal as the
TransformerDFT path. This proves end-to-end that the slow direct-DFT and
fast nufftax-NUFFT paths produce the same likelihood, catching any future
drift between the two transformers.

Per-script:
- lp.py / mge.py / mge_group.py / delaunay.py / rectangular.py /
  rectangular_dspl.py: standard Path B append, asserts literal at rtol=1e-4.
- delaunay_mge.py: rtol loosened to 2e-3. The Delaunay+MGE inversion
  amplifies the ~1e-13 numerical difference between DFT and nufftax in
  the forward operator into a ~5e-4 relative shift in the final
  log-likelihood (likely via mesh-vertex selection sensitivity).
- rectangular_mge.py / rectangular_dspl.py: include `gc.collect()`,
  `jax.clear_caches()`, and `parameters_nufft = parameters[:1]` before
  Path B to keep peak memory within the 16 GB CI box.
- rectangular_sparse.py: three-way cross-check — DFT+sparse_operator
  (existing literal) -> DFT-no-sparse and TransformerNUFFT (no sparse).
  The latter two assert against rectangular.py's canonical -3164.286252
  literal (apply_sparse_operator gives a numerically distinct ~0.4%
  off result from the bare DFT, hence the path-specific literals).

DSPL mesh reduction:
- Both interferometer/rectangular_dspl.py and imaging/rectangular_dspl.py
  reduce mesh_shape from (30, 30) to (8, 8) to match the resolution other
  rectangular tests use. Test scripts should be lightweight; (30, 30)
  produced 1800-pixel JIT traces and slow compiles. New literals captured
  empirically: -3170.19672623 (interferometer), -3797.73182794 (imaging).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 056aac0 into main May 10, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax-likelihood-nufft-crosscheck branch May 10, 2026 13:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant