
Make use_mixed_precision actually emit fp32 FFT for light profiles #302

Merged

Jammy2211 merged 1 commit into main from feature/fft-mixed-precision-fix on May 8, 2026

Conversation

@Jammy2211
Collaborator

Summary

  • Fixes Convolver.convolved_image_from so use_mixed_precision=True actually emits a complex64 FFT end-to-end. Previously the input was force-cast to jnp.float64 (line 539) and only the result was narrowed at the end (line 581) — a net loss on consumer GPUs. See the sketch after this list.
  • Pre-caches the complex64 kernel on ConvolverState.fft_kernel_c64 so the per-call astype doesn't show up in CPU profiles.
  • Documents the intentional asymmetry in convolved_mapping_matrix_from: fp32 input cube + complex128 kernel multiply. Full fp32 in that path drifted figure_of_merit by 1.9% on the delaunay_mge.py regression (K=780 source mesh) — pixelization NNLS and log-determinant need fp64.
  • Tightens the Settings.use_mixed_precision docstring to enumerate exactly which paths honor the flag and notes the GPU/CPU asymmetry.
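
A minimal sketch of the fixed light-profile path, assuming JAX with x64 enabled. `fft_kernel_c64` is the cached attribute named in this PR; the `fft_kernel` name and the function body are illustrative, not the library's actual code:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the fp64 path needs x64 mode


def convolved_image_from(image_2d, state, use_mixed_precision=False):
    if use_mixed_precision:
        # fp32 input: rfft2 returns complex64, so the whole FFT runs in
        # reduced precision. The complex64 kernel is read from the state
        # rather than astype'd on every call.
        image = image_2d.astype(jnp.float32)
        kernel_fft = state.fft_kernel_c64
    else:
        image = image_2d.astype(jnp.float64)
        kernel_fft = state.fft_kernel  # complex128 (name assumed here)
    image_fft = jnp.fft.rfft2(image)
    # complex64 * complex64 stays complex64; irfft2 then returns fp32.
    return jnp.fft.irfft2(image_fft * kernel_fft, s=image.shape)
```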

API Changes

None. use_mixed_precision: bool = False on Settings, Convolver.convolved_image_from, and Convolver.convolved_mapping_matrix_from keeps its existing signature. Only the JAX FFT path's behaviour changes — the docstring is now accurate.
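
Illustrative call sites under that unchanged signature — the keyword is from this PR; the surrounding names are placeholders:

```python
# Hypothetical usage; `convolver`, `image`, and `mapping_matrix` are
# placeholder names. Only the use_mixed_precision keyword is from this PR.
blurred_image = convolver.convolved_image_from(
    image, use_mixed_precision=True  # complex64 FFT end-to-end
)
blurred_mapping = convolver.convolved_mapping_matrix_from(
    mapping_matrix, use_mixed_precision=True  # asymmetric: fp32 FFT, fp64 multiply
)
```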

Numerical impact

MGE imaging regression on RTX 2060 + i9-10885H (HST-shaped, 15k masked pixels, 40 linear Gaussians):

| Config   | Full pipeline (single JIT) | vmap per call        | Δ log-likelihood |
|----------|----------------------------|----------------------|------------------|
| GPU fp64 | 21.8 ms                    | 17.4 ms              |                  |
| GPU mp   | 19.6 ms (10% faster)       | 8.9 ms (49% faster)  | 2.2e-3           |
| CPU fp64 | 218.9 ms                   | 141.5 ms             |                  |
| CPU mp   | 255.8 ms (17% slower)      | 138.8 ms (2% faster) | 6.5e-5           |

The headline result is GPU vmap — the production sampler hot path — going from 17.4 ms to 8.9 ms (49% faster). CPU single-JIT regresses ~17% because XLA's CPU FFT lowering doesn't compose well across the 40-FFT MGE-basis pipeline; CPU vmap is unchanged to slightly faster, so production sampling on CPU is not affected. Δ log-likelihood is far below the natural χ² noise floor (σ ≈ √(2N) ≈ 173 for N = 15k).
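
A quick check of that floor (plain NumPy; the 2.2e-3 figure is the GPU Δ from the table above):

```python
import numpy as np

# For a chi^2 statistic with N degrees of freedom, std(chi^2) = sqrt(2N).
N = 15_000                   # masked pixels in the regression image
sigma = np.sqrt(2 * N)       # ~173
print(f"noise floor ~ {sigma:.0f}, mp shift = {2.2e-3 / sigma:.1e} of it")
```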

Why convolved_mapping_matrix_from is asymmetric

The pre-existing pattern of declaring fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 and then never using it looks like a bug, but in this PR I verified that wiring it in causes a real numerical regression: autolens_workspace_test/scripts/jax_likelihood_functions/imaging/delaunay_mge.py shifted log-likelihood by ~10 (1.9% relative), past rtol=1e-4. Pixelization figure_of_merit accumulates fp32 round-off through the NNLS active-set and regularization log-determinant. This PR codifies that the multiply-upcast-to-complex128 is intentional, with a comment explaining it so a future reader does not "complete" the fix and break the same regression.
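
A sketch of the shape of that asymmetric path, assuming JAX with x64 enabled; `state.fft_kernel` stands in for the cached complex128 kernel and the body is illustrative, not the library's code:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # needed for complex128/fp64


def convolved_mapping_matrix_from(mapping_cube, state):
    # fp32 cube: the scatter into image space and the forward FFT are
    # the expensive steps, and they tolerate reduced precision.
    cube = mapping_cube.astype(jnp.float32)
    cube_fft = jnp.fft.rfft2(cube)  # complex64, batched over leading axis
    # Deliberate upcast at the multiply: the product is complex128, so
    # irfft2 hands fp64 to the NNLS solve and the regularization
    # log-determinant downstream. Do NOT "complete" the fp32 path here.
    conv_fft = cube_fft.astype(jnp.complex128) * state.fft_kernel
    return jnp.fft.irfft2(conv_fft, s=cube.shape[-2:])
```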

Test plan

  • pytest test_autoarray/operators/test_convolver.py — 9 passed
  • New convolver_mixed_precision.py jax_assertion in autogalaxy_workspace_test (companion PR) — passes on GPU and CPU
  • Full JAX likelihood-function integration suite — 23/23 pass (autolens + autogalaxy, imaging + interferometer, MGE + rectangular + Delaunay)
  • mge.py profiling baseline numbers above

🤖 Generated with Claude Code

Commit message:

The flag previously forced fp64 FFT internally in convolved_image_from and
only downcast at the end -- a net loss on consumer GPUs (mp full pipeline
27% slower than fp64 on RTX 2060). The light-profile FFT path now runs
end-to-end complex64, with the kernel pre-cached on
ConvolverState.fft_kernel_c64 to keep the per-call astype out of CPU
profiles.

convolved_mapping_matrix_from intentionally keeps the fp64 kernel multiply
to preserve pixelization figure_of_merit precision: full fp32 in that path
caused 1.9% relative drift on the autolens_workspace_test delaunay_mge
regression (K=780 source mesh). The fp32 input cube and forward rfft2 are
kept for the cheaper scatter and FFT, but the multiply upcasts back to
complex128 and the irfft2 returns fp64.

Empirical impact (RTX 2060 + i9-10885H, mge.py HST-shaped regression):
  GPU mp full pipeline: 47 -> 19.6 ms
  GPU mp vmap (production hot path): 18 -> 8.9 ms (49% faster)
  CPU vmap: ~unchanged
  Delta log-likelihood: 2.2e-3 absolute, far below chi2 noise floor

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Jammy2211 merged commit 9b4df25 into main on May 8, 2026
4 checks passed
Jammy2211 deleted the feature/fft-mixed-precision-fix branch on May 8, 2026 at 21:01
