feat(nss): chunk_size kwarg for inversion-heavy A100 likelihoods #1303

Merged: Jammy2211 merged 1 commit into main from feature/nss-chunked-vmap on May 29, 2026

@Jammy2211 (Collaborator)

Summary

af.NSS (JAX-native nested slice sampler) is 7.5× faster per likelihood eval than af.Nautilus on lightweight parametric cells but cannot run at all on inversion-heavy cells (pixelization, Delaunay) at A100 80 GB scale. The blocker is a single un-chunked jax.vmap inside blackjax.ns.from_mcmc.update_with_mcmc_take_last that fans out num_delete particles in parallel; even num_delete=16 blows past 80 GB on a ~922 MB/replica likelihood.

This PR adds chunk_size: Optional[int] = None to af.NSS. When set and < num_delete, the inner vmap becomes jax.lax.map(batch_size=chunk_size) — same parallelism within a chunk, sequential across chunks. Peak GPU memory becomes chunk_size × per_particle_state. chunk_size=None is bit-identical to the previous behaviour; chunk_size >= num_delete is a no-op fallthrough.

The swap rides on blackjax.nss(update_strategy=...), already exposed in upstream. PyAutoFit ships a local make_chunked_update_strategy(chunk_size) factory — no blackjax patch required.

Refs #1301

API Changes

af.NSS.__init__ gains an optional chunk_size: Optional[int] = None kwarg. When unset, behaviour is bit-identical to the previous version. When set and below num_delete, the internal blackjax.nss(update_strategy=...) is swapped to a chunked variant that uses jax.lax.map(batch_size=chunk_size) instead of jax.vmap inside the inner MCMC kernel.

New module autofit.non_linear.search.nest.nss._chunked_update exposes make_chunked_update_strategy(chunk_size) for advanced callers who want the underlying factory directly.

See full details below.

Test Plan

  • pytest test_autofit/non_linear/search/nest/nss/ — 10/10 pass; new tests cover the chunk_size kwarg plumbing and the _chunked_update factory signature.
  • JAX-traced smoke (5D Gaussian, n_live=20, num_delete=4): chunk_size=None and chunk_size=2 produce bit-identical log_Z = -4.3208251152 on the same seed.
  • A100 follow-up (separate autolens_profiling PR): resubmit NSS pixelization + delaunay × HST × fp64 with chunk_size=16; confirm completion within ~3× of the Nautilus baseline (84.8 ms/eval, 45 min on delaunay).

Full API Changes (for automation & release notes)

Added

  • af.NSS(chunk_size: Optional[int] = None, ...) — new optional kwarg on af.NSS.__init__. Documented as a GPU-memory knob. Default None preserves the un-chunked behaviour.
  • autofit.non_linear.search.nest.nss._chunked_update.make_chunked_update_strategy(chunk_size: Optional[int]) -> Callable — factory that returns an update_strategy-compatible callable (matching blackjax.ns.from_mcmc.update_with_mcmc_take_last's three-arg signature).
  • af.NSS.chunk_size: Optional[int] — stored on the search instance.

Changed Behaviour

  • The NSS configuration INFO log line gains a chunk_size=<value> field (after num_delete).
  • When chunk_size is not None and chunk_size < num_delete, af.NSS._fit constructs blackjax.nss(..., update_strategy=make_chunked_update_strategy(chunk_size)) instead of the default. The default path (chunk_size=None or >= num_delete) is unchanged.
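
The selection rule in the last bullet can be written as a small predicate. `uses_chunked_strategy` is a hypothetical helper name used only for illustration; it is an assumption about how the condition could be expressed, not the code in `af.NSS._fit`.

```python
from typing import Optional


def uses_chunked_strategy(chunk_size: Optional[int], num_delete: int) -> bool:
    """Hypothetical helper: True when the chunked update strategy would be used."""
    return chunk_size is not None and chunk_size < num_delete


# Default path: chunk_size unset, or large enough to cover every deleted particle.
assert uses_chunked_strategy(None, 16) is False
assert uses_chunked_strategy(16, 16) is False
assert uses_chunked_strategy(32, 16) is False

# Chunked path: chunk_size strictly below num_delete.
assert uses_chunked_strategy(4, 16) is True
```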

Migration

  • None. All existing code (no chunk_size kwarg) continues to work bit-identically. Opt in by passing chunk_size=<int> per the docstring.

Identifier hash

  • chunk_size is NOT added to __identifier_fields__. The kwarg is a memory-layout hint that does not affect the posterior, so two NSS runs with different chunk_size should share the same identifier and reuse cached results.
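
A toy sketch of why leaving a field out of the identifier lets runs share cached results. The field list, the `identifier` function and the hashing scheme are all hypothetical and do not reproduce PyAutoFit's actual identifier logic.

```python
import hashlib
import json

# Hypothetical field list: chunk_size is deliberately absent.
IDENTIFIER_FIELDS = ("n_live", "num_delete")


def identifier(settings: dict) -> str:
    """Hash only the settings assumed to affect the posterior."""
    relevant = {name: settings[name] for name in IDENTIFIER_FIELDS}
    payload = json.dumps(relevant, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


run_a = {"n_live": 20, "num_delete": 4, "chunk_size": None}
run_b = {"n_live": 20, "num_delete": 4, "chunk_size": 2}
run_c = {"n_live": 40, "num_delete": 4, "chunk_size": 2}

# Runs that differ only in chunk_size map to the same identifier.
assert identifier(run_a) == identifier(run_b)
# Changing a hashed field yields a different identifier.
assert identifier(run_a) != identifier(run_c)
```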

🤖 Generated with Claude Code

Adds an optional chunk_size knob to af.NSS that swaps blackjax's internal
jax.vmap(num_delete) fan-out for jax.lax.map(batch_size=chunk_size) inside
update_with_mcmc_take_last. Peak GPU memory becomes chunk_size ×
per_particle_state instead of num_delete × per_particle_state.

blackjax.nss(...) already exposes update_strategy as a kwarg, so the swap
is a clean drop-in via the new make_chunked_update_strategy(chunk_size)
factory — no blackjax changes required. chunk_size=None preserves
bit-identical un-chunked behaviour; chunk_size >= num_delete is also a
no-op fallthrough.

Unblocks NSS on PyAutoLens pixelization / Delaunay cells at A100 80 GB,
where prior attempts OOMed at ~28 GB on every num_delete=16 retry
(autolens_profiling jobs 322592/96/600/602/604; see #1301 for per-cell
evidence).

5D Gaussian smoke confirms bit-identical log_Z between chunk_size=None
and chunk_size=2 on the same seed.

Refs #1301

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@Jammy2211 (Collaborator, Author)

Workspace PR: PyAutoLabs/autolens_profiling#43

Jammy2211 merged commit c161235 into main on May 29, 2026 (7 checks passed)
Jammy2211 deleted the feature/nss-chunked-vmap branch on May 29, 2026 07:58
Jammy2211 added a commit that referenced this pull request on May 29, 2026:

PyAutoFit#1303 chunked the per-iteration MCMC step's jax.vmap(num_delete)
but left a separate hardcoded jax.vmap(init_state_fn) inside blackjax's
nss.as_top_level_api init_fn unchunked. The A100 retry of the cells
#1303 was supposed to unblock (autolens_profiling NSS pixelization +
delaunay × HST × fp64, jobs 322605 + 322606) OOMed at the same byte
counts as before #1303 (28.05 GB pix, 27.67 GB delaunay); the
"NSS configuration:" log line never appears, confirming the crash is in
algo.init, not algo.step.

This PR ships the missing init-side chunking. New module
autofit/non_linear/search/nest/nss/_chunked_nss.py exposes
build_chunked_nss_algorithm — a ~30-line local replica of
blackjax.nss.as_top_level_api that controls both vmap sites:

- step path: make_chunked_update_strategy from #1303 (unchanged)
- init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size)
  instead of jax.vmap(init_state_fn)

af.NSS._fit switches to the local builder when
chunk_size < max(n_live, num_delete). chunk_size=None or
chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...)
bit-for-bit.

5D Gaussian smoke at n_live=20 (5 init chunks of 4): bit-identical
log_Z = -4.3208251152 between chunk_size=None and chunk_size=4. Both
configuration log lines print the chunk_size value, confirming both
code paths fired.
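
The chunk count quoted for the smoke test, and the builder-selection rule above, can be checked with a few lines of arithmetic. Both helper names are hypothetical and only restate the conditions given in this message.

```python
import math
from typing import Optional


def uses_local_builder(chunk_size: Optional[int], n_live: int, num_delete: int) -> bool:
    """Hypothetical helper: True when the local chunked builder would be selected."""
    return chunk_size is not None and chunk_size < max(n_live, num_delete)


def num_init_chunks(n_live: int, chunk_size: int) -> int:
    """Number of sequential chunks needed to initialise n_live particles."""
    return math.ceil(n_live / chunk_size)


# Smoke-test configuration: n_live=20, num_delete=4, chunk_size=4.
assert uses_local_builder(4, n_live=20, num_delete=4) is True
assert num_init_chunks(20, 4) == 5

# chunk_size unset, or at least max(n_live, num_delete): upstream path.
assert uses_local_builder(None, n_live=20, num_delete=4) is False
assert uses_local_builder(20, n_live=20, num_delete=4) is False
```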

Refs #1301, #1303, #1304

Co-authored-by: Jammy2211 <JNightingale2211@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Labels

pending-release PR queued for the next release build