Make use_mixed_precision actually emit fp32 FFT for light profiles #302
Merged
Conversation
The flag previously forced fp64 FFT internally in convolved_image_from and only downcast at the end -- a net loss on consumer GPUs (mp full pipeline 27% slower than fp64 on RTX 2060). The light-profile FFT path now runs end-to-end complex64, with the kernel pre-cached on ConvolverState.fft_kernel_c64 to keep the per-call astype out of CPU profiles.

convolved_mapping_matrix_from intentionally keeps the fp64 kernel multiply to preserve pixelization figure_of_merit precision: full fp32 in that path caused 1.9% relative drift on the autolens_workspace_test delaunay_mge regression (K=780 source mesh). The fp32 input cube and forward rfft2 are kept for the cheaper scatter and FFT, but the multiply upcasts back to complex128 and the irfft2 returns fp64.

Empirical impact (RTX 2060 + i9-10885H, mge.py HST-shaped regression):

- GPU mp full pipeline: 47 -> 19.6 ms
- GPU mp vmap (production hot path): 18 -> 8.9 ms (49% faster)
- CPU vmap: ~unchanged
- Delta log-likelihood: 2.2e-3 absolute, far below chi2 noise floor

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Summary
- Reworked `Convolver.convolved_image_from` so `use_mixed_precision=True` actually emits a `complex64` FFT end-to-end (sketched after this list). Previously the input was force-cast to `jnp.float64` (line 539) and only the result was narrowed at the end (line 581) — a net loss on consumer GPUs.
- Pre-cached the complex64 kernel on `ConvolverState.fft_kernel_c64` so the per-call astype doesn't show up in CPU profiles.
- `convolved_mapping_matrix_from`: fp32 input cube + complex128 kernel multiply. Full fp32 in that path drifted `figure_of_merit` by 1.9% on the `delaunay_mge.py` regression (K=780 source mesh) — pixelization NNLS and log-determinant need fp64.
- Updated the `Settings.use_mixed_precision` docstring to enumerate exactly which paths honor the flag and to note the GPU/CPU asymmetry.
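For orientation, a minimal sketch of the single-precision light-profile path, with the surrounding 1D/2D bookkeeping omitted. Names like `convolve_image_fp32`, `image_2d`, and `fft_shape` are illustrative, not the PR's exact code; the cached kernel argument plays the role of `ConvolverState.fft_kernel_c64`:

```python
import jax.numpy as jnp


def convolve_image_fp32(image_2d, fft_kernel_c64, fft_shape):
    """Sketch of the end-to-end complex64 light-profile convolution.

    `fft_kernel_c64` stands in for ConvolverState.fft_kernel_c64: the PSF's
    rfft2, cast to complex64 once at construction so no per-call astype
    shows up in CPU profiles.
    """
    # rfft2 of a float32 input stays in single precision (complex64 output).
    image_fft = jnp.fft.rfft2(image_2d.astype(jnp.float32), s=fft_shape)

    # Pointwise multiply against the pre-cast kernel -- no hidden upcast.
    convolved_fft = image_fft * fft_kernel_c64

    # irfft2 of complex64 returns float32, so no trailing downcast is needed.
    return jnp.fft.irfft2(convolved_fft, s=fft_shape)
```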
API Changes

None. The `use_mixed_precision: bool = False` flag on `Settings`, `Convolver.convolved_image_from`, and `Convolver.convolved_mapping_matrix_from` keeps its existing signature. Only the JAX FFT path's behaviour changes — the docstring is now accurate.

Numerical impact
MGE imaging regression on RTX 2060 + i9-10885H (HST-shaped, 15k masked pixels, 40 linear Gaussians):
The headline result is GPU vmap — the production sampler hot path — going from 17.4 → 8.9 ms (49% faster). CPU single-JIT regresses ~17% because XLA's CPU FFT lowering doesn't compose well across the 40-FFT MGE-basis pipeline; CPU vmap is unchanged-to-slightly-faster, so production sampling on CPU is not affected. Δlog-likelihood is far below the natural χ² noise floor (σ ≈ √(2N) ≈ 175 for N=15k).
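As a back-of-envelope check of that noise-floor claim (not code from the PR; `n_pixels` is read off the "15k masked pixels" benchmark setup), the standard deviation of a χ² statistic with N degrees of freedom is √(2N):

```python
import math

n_pixels = 15_000                      # "15k masked pixels" from the benchmark setup
sigma_chi2 = math.sqrt(2 * n_pixels)   # std of chi-squared with N degrees of freedom
print(round(sigma_chi2, 1))            # ~173, in line with the quoted sigma of roughly 175
# The measured delta log-likelihood of 2.2e-3 is ~5 orders of magnitude below this.
```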
Why convolved_mapping_matrix_from is asymmetric

The pre-existing pattern of declaring `fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128` and then never using it looks like a bug, but on this PR I verified that wiring it in causes a real numerical regression: `autolens_workspace_test/scripts/jax_likelihood_functions/imaging/delaunay_mge.py` shifted log-likelihood by ~10 (1.9% relative), past `rtol=1e-4`. Pixelization figure_of_merit accumulates fp32 round-off through the NNLS active-set and regularization log-determinant. This PR codifies that the multiply-upcast-to-complex128 is intentional, with a comment explaining it so a future reader does not "complete" the fix and break the same regression.
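A hedged sketch of the asymmetry being codified (function and argument names are illustrative, not the exact code; assumes `jax_enable_x64` is on, as it must be for the fp64 kernel): the input cube and forward rfft2 stay single precision, the kernel multiply deliberately upcasts to complex128, and the irfft2 hands fp64 back to the pixelization linear algebra.

```python
import jax.numpy as jnp


def convolve_mapping_matrix_mixed(mapping_cube_fp32, fft_kernel_c128, fft_shape):
    """Illustrative sketch of the intentional fp32-in / fp64-out asymmetry."""
    # Cheap forward transform: a float32 cube yields a complex64 FFT.
    cube_fft = jnp.fft.rfft2(mapping_cube_fp32, s=fft_shape)

    # Deliberate upcast before the kernel multiply. Do NOT "complete" the
    # mixed-precision fix here: a complex64 multiply drifts figure_of_merit
    # by ~1.9% on the delaunay_mge regression.
    convolved_fft = cube_fft.astype(jnp.complex128) * fft_kernel_c128

    # irfft2 of complex128 returns float64 for the NNLS / log-determinant.
    return jnp.fft.irfft2(convolved_fft, s=fft_shape)
```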
Test plan

- `pytest test_autoarray/operators/test_convolver.py` — 9 passed
- `convolver_mixed_precision.py` jax_assertion in `autogalaxy_workspace_test` (companion PR) — passes on GPU and CPU

🤖 Generated with Claude Code