test: TransformerNUFFT cross-check (Path B) for interferometer JAX likelihood scripts#89
Merged
Merged
Conversation
…kelihood scripts Each interferometer JAX likelihood script now appends a Path B that re-runs the same vmap likelihood with `transformer_class=TransformerNUFFT` (the new nufftax-backed default) and asserts the same hardcoded literal as the TransformerDFT path. This proves end-to-end that the slow direct-DFT and fast nufftax-NUFFT paths produce the same likelihood, catching any future drift between the two transformers. Per-script: - lp.py / mge.py / mge_group.py / delaunay.py / rectangular.py / rectangular_dspl.py: standard Path B append, asserts literal at rtol=1e-4. - delaunay_mge.py: rtol loosened to 2e-3. The Delaunay+MGE inversion amplifies the ~1e-13 numerical difference between DFT and nufftax in the forward operator into a ~5e-4 relative shift in the final log-likelihood (likely via mesh-vertex selection sensitivity). - rectangular_mge.py / rectangular_dspl.py: include `gc.collect()`, `jax.clear_caches()`, and `parameters_nufft = parameters[:1]` before Path B to keep peak memory within the 16 GB CI box. - rectangular_sparse.py: three-way cross-check — DFT+sparse_operator (existing literal) -> DFT-no-sparse and TransformerNUFFT (no sparse). The latter two assert against rectangular.py's canonical -3164.286252 literal (apply_sparse_operator gives a numerically distinct ~0.4% off result from the bare DFT, hence the path-specific literals). DSPL mesh reduction: - Both interferometer/rectangular_dspl.py and imaging/rectangular_dspl.py reduce mesh_shape from (30, 30) to (8, 8) to match the resolution other rectangular tests use. Test scripts should be lightweight; (30, 30) produced 1800-pixel JIT traces and slow compiles. New literals captured empirically: -3170.19672623 (interferometer), -3797.73182794 (imaging). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a TransformerNUFFT cross-check (Path B) to each interferometer JAX likelihood script. Each script keeps its existing TransformerDFT path (Path A) untouched, then re-runs the same vmap likelihood with
transformer_class=al.TransformerNUFFTand asserts the same hardcoded literal atrtol=1e-4. This proves end-to-end that the slow direct-DFT path and the fast nufftax-backed NUFFT path produce the same likelihood, catching any future drift between the two transformers.Companion library PR
PyAutoArray PR #305 —
perf: batched transform_mapping_matrix. The DSPL Path B in this workspace PR depends on that library fix to compile in seconds instead of minutes. The library PR should land first.Per-script changes
lp.py/mge.py/mge_group.py/delaunay.py/rectangular.py/rectangular_dspl.py— standard Path B append, asserts the literal atrtol=1e-4.delaunay_mge.py— rtol loosened to2e-3. The Delaunay+MGE inversion amplifies the ~1e-13 forward-operator difference into a ~5e-4 relative shift in the final log-likelihood (likely mesh-vertex selection sensitivity).rectangular_mge.py/rectangular_dspl.py— includegc.collect(),jax.clear_caches(), andparameters_nufft = parameters[:1]before Path B to keep peak memory within the 16 GB CI box.rectangular_sparse.py— three-way cross-check: DFT+sparse_operator (existing literal) → DFT-no-sparse → TransformerNUFFT (no sparse). The latter two assert againstrectangular.py's canonical-3164.286252literal becauseapply_sparse_operatorgives a numerically distinct ~0.4% off result from the bare DFT, hence the path-specific literals.DSPL mesh reduction
Both
interferometer/rectangular_dspl.pyandimaging/rectangular_dspl.pyreducemesh_shapefrom(30, 30)to(8, 8)to match the resolution the otherrectangular*.pytests use. Test scripts should be lightweight;(30, 30)produced 1800-pixel JIT traces and slow compiles. New canonical literals captured empirically:interferometer/rectangular_dspl.py:-3170.19672623imaging/rectangular_dspl.py:-3797.73182794Scripts Changed
scripts/jax_likelihood_functions/interferometer/{lp,mge,mge_group,delaunay,delaunay_mge,rectangular,rectangular_mge,rectangular_dspl,rectangular_sparse}.pyscripts/jax_likelihood_functions/imaging/rectangular_dspl.pyTest plan
python <script>.pyends withPASS: TransformerNUFFT cross-check matches TransformerDFT.imaging/rectangular_dspl.py: passes the existing assertion at the new literal.🤖 Generated with Claude Code