feat(nss): chunk_size kwarg for inversion-heavy A100 likelihoods#1303
Merged
Conversation
Adds an optional chunk_size knob to af.NSS that swaps blackjax's internal jax.vmap(num_delete) fan-out for jax.lax.map(batch_size=chunk_size) inside update_with_mcmc_take_last. Peak GPU memory becomes chunk_size × per_particle_state instead of num_delete × per_particle_state. blackjax.nss(...) already exposes update_strategy as a kwarg, so the swap is a clean drop-in via the new make_chunked_update_strategy(chunk_size) factory — no blackjax changes required. chunk_size=None preserves bit-identical un-chunked behaviour; chunk_size >= num_delete is also a no-op fallthrough. Unblocks NSS on PyAutoLens pixelization / Delaunay cells at A100 80 GB, where prior attempts OOMed at ~28 GB on every num_delete=16 retry (autolens_profiling jobs 322592/96/600/602/604; see #1301 for per-cell evidence). 5D Gaussian smoke confirms bit-identical log_Z between chunk_size=None and chunk_size=2 on the same seed. Refs #1301 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This was referenced May 29, 2026
Collaborator
Author
|
Workspace PR: PyAutoLabs/autolens_profiling#43 |
This was referenced May 29, 2026
Jammy2211
added a commit
that referenced
this pull request
May 29, 2026
PyAutoFit#1303 chunked the per-iteration MCMC step's jax.vmap(num_delete) but left a separate hardcoded jax.vmap(init_state_fn) inside blackjax's nss.as_top_level_api init_fn unchunked. A100 retry of the cells #1303 was supposed to unblock (autolens_profiling NSS pixelization + delaunay × HST × fp64, jobs 322605 + 322606) OOM at the same byte counts as before #1303 (28.05 GB pix, 27.67 GB delaunay); the "NSS configuration:" log line never appears, confirming the crash is in algo.init not algo.step. This PR ships the missing init-side chunking. New module autofit/non_linear/search/nest/nss/_chunked_nss.py exposes build_chunked_nss_algorithm — a ~30-line local replica of blackjax.nss.as_top_level_api that controls both vmap sites: - step path: make_chunked_update_strategy from #1303 (unchanged) - init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size) instead of jax.vmap(init_state_fn) af.NSS._fit switches to the local builder when chunk_size < max(n_live, num_delete). chunk_size=None or chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...) bit-for-bit. 5D Gaussian smoke at n_live=20 (5 init chunks of 4): bit-identical log_Z = -4.3208251152 between chunk_size=None and chunk_size=4. Both configuration log lines print the chunk_size value, confirming both code paths fired. Refs #1301, #1303, #1304 Co-authored-by: Jammy2211 <JNightingale2211@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
af.NSS(JAX-native nested slice sampler) is 7.5× faster per likelihood eval thanaf.Nautiluson lightweight parametric cells but cannot run at all on inversion-heavy cells (pixelization, Delaunay) at A100 80 GB scale. The blocker is a single un-chunkedjax.vmapinsideblackjax.ns.from_mcmc.update_with_mcmc_take_lastthat fans outnum_deleteparticles in parallel; evennum_delete=16blows past 80 GB on a ~922 MB/replica likelihood.This PR adds
chunk_size: Optional[int] = Nonetoaf.NSS. When set and< num_delete, the inner vmap becomesjax.lax.map(batch_size=chunk_size)— same parallelism within a chunk, sequential across chunks. Peak GPU memory becomeschunk_size × per_particle_state.chunk_size=Noneis bit-identical to the previous behaviour;chunk_size >= num_deleteis a no-op fallthrough.The swap rides on
blackjax.nss(update_strategy=...), already exposed in upstream. PyAutoFit ships a localmake_chunked_update_strategy(chunk_size)factory — no blackjax patch required.Refs #1301
API Changes
af.NSS.__init__gains an optionalchunk_size: Optional[int] = Nonekwarg. When unset, behaviour is bit-identical to the previous version. When set and belownum_delete, the internalblackjax.nss(update_strategy=...)is swapped to a chunked variant that usesjax.lax.map(batch_size=chunk_size)instead ofjax.vmapinside the inner MCMC kernel.New module
autofit.non_linear.search.nest.nss._chunked_updateexposesmake_chunked_update_strategy(chunk_size)for advanced callers who want the underlying factory directly.See full details below.
Test Plan
pytest test_autofit/non_linear/search/nest/nss/— 10/10 pass; new tests cover thechunk_sizekwarg plumbing and the_chunked_updatefactory signature.n_live=20, num_delete=4):chunk_size=Noneandchunk_size=2produce bit-identicallog_Z = -4.3208251152on the same seed.autolens_profilingPR): resubmit NSS pixelization + delaunay × HST × fp64 withchunk_size=16; confirm completion within ~3× of the Nautilus baseline (84.8 ms/eval, 45 min on delaunay).Full API Changes (for automation & release notes)
Added
af.NSS(chunk_size: Optional[int] = None, ...)— new optional kwarg onaf.NSS.__init__. Documented as a GPU-memory knob. DefaultNonepreserves the un-chunked behaviour.autofit.non_linear.search.nest.nss._chunked_update.make_chunked_update_strategy(chunk_size: Optional[int]) -> Callable— factory that returns anupdate_strategy-compatible callable (matchingblackjax.ns.from_mcmc.update_with_mcmc_take_last's three-arg signature).af.NSS.chunk_size: Optional[int]— stored on the search instance.Changed Behaviour
chunk_size=<value>field (afternum_delete).chunk_size is not None and chunk_size < num_delete,af.NSS._fitconstructsblackjax.nss(..., update_strategy=make_chunked_update_strategy(chunk_size))instead of the default. The default path (chunk_size=Noneor>= num_delete) is unchanged.Migration
chunk_sizekwarg) continues to work bit-identically. Opt in by passingchunk_size=<int>per the docstring.Identifier hash
chunk_sizeis NOT added to__identifier_fields__. The kwarg is a memory-layout hint that does not affect the posterior, so two NSS runs with differentchunk_sizeshould share the same identifier and reuse cached results.🤖 Generated with Claude Code