fix(nss): chunked algo.init follow-up to #1303 #1304

@Jammy2211

Overview

PyAutoFit#1303 (feat(nss): chunk_size kwarg for inversion-heavy A100 likelihoods, merged c161235) chunked the num_delete-wide jax.vmap in the per-iteration MCMC step, but did not address a separate hardcoded jax.vmap(init_state_fn) inside blackjax.ns.nss.as_top_level_api's init_fn (blackjax/ns/nss.py:223-230). The A100 retries on autolens_profiling NSS pixelization + delaunay × HST × fp64 (jobs 322605 + 322606) still OOM at the same byte counts as before #1303 landed. The "NSS configuration:" INFO log line, which runs after algo.init, never appears, so the crash is in algo.init, not in algo.step.

This issue ships the missing init-side chunking. PyAutoFit gains a local build_chunked_nss_algorithm that replicates blackjax.ns.nss.as_top_level_api (~30 lines) so we control both the chunked update_strategy (already done by #1303) and a chunked init_fn that swaps jax.vmap(init_state_fn) for jax.lax.map(init_state_fn, positions, batch_size=chunk_size). af.NSS._fit uses the local builder when chunk_size is set; chunk_size=None keeps using upstream blackjax.nss(...) bit-for-bit.

Together #1303 + this PR cover both NSS vmap paths and unblock NSS profiling on production lensing cells (SLaM source_pix[1/2] Delaunay / pixelization) at A100 80 GB.

Plan

  • New module autofit/non_linear/search/nest/nss/_chunked_nss.py exposing build_chunked_nss_algorithm(*, logprior_fn, loglikelihood_fn, num_inner_steps, num_delete, chunk_size). Internally it calls make_chunked_update_strategy(chunk_size) from PyAutoFit#1303's _chunked_update.py for the step path, and swaps init_state_fn=jax.vmap(init_state_fn) for init_state_fn=lambda p: jax.lax.map(init_state_fn, p, batch_size=chunk_size) for the init path.
  • af.NSS._fit: when chunk_size is not None and chunk_size < max(n_live, num_delete), build algo via the local function. Otherwise continue calling _blackjax.nss(...). Default chunk_size=None preserves bit-identical un-chunked behaviour.
  • Unit tests (no JAX): the new factory returns a SamplingAlgorithm with the expected (init, step) callables; af.NSS plumbing tests continue to pass.
  • Manual A100 validation: resubmit autolens_profiling NSS pixelization + delaunay × HST × fp64 (chunk_size=16 set automatically by build_nss). Compare against Nautilus baselines: pixelization 46.5 ms/eval / 46 min (322603), delaunay 84.8 ms/eval / 45 min (322601), NSS MGE 1.6 ms/eval / 11 min (322590).
Detailed implementation plan

Affected Repositories

  • rhayes777/PyAutoFit (primary: af.NSS._fit switch + new _chunked_nss.py)
  • PyAutoLabs/autolens_profiling (consumer: no source change needed; build_nss already sets chunk_size=vmap_batch_for_cell(...) per autolens_profiling#43. Once this lands and HPCPullPyAuto has been run on the HPC, the A100 resubmits should just work.)
  • handley-lab/blackjax (external, deferred. Long-term cleanup is upstreaming an init_batcher kwarg to as_top_level_api so the PyAutoFit-local replica can shrink to a thin forwarder.)

Work Classification

Library (PyAutoFit). No workspace changes; autolens_profiling consumer is already wired correctly.

Branch Survey

Repository             Current Branch   Dirty?
./PyAutoFit            main             clean
./autolens_profiling   main             clean (no edits needed)

Suggested branch: feature/nss-chunked-init
Worktree root: ~/Code/PyAutoLabs-wt/nss-chunked-init/ (created later by /start_library)

Implementation Steps

  1. Add autofit/non_linear/search/nest/nss/_chunked_nss.py (new file). Replicate the body of blackjax.ns.nss.as_top_level_api:

    from functools import partial
    import jax
    from blackjax import SamplingAlgorithm
    from blackjax.ns.adaptive import init as ns_init
    from blackjax.ns.base import init_state_strategy
    from blackjax.ns.nss import build_kernel, update_inner_kernel_params
    
    from ._chunked_update import make_chunked_update_strategy
    
    
    def build_chunked_nss_algorithm(
        *,
        logprior_fn,
        loglikelihood_fn,
        num_inner_steps,
        num_delete,
        chunk_size,
    ):
        init_state_fn = partial(
            init_state_strategy,
            logprior_fn=logprior_fn,
            loglikelihood_fn=loglikelihood_fn,
        )
    
        kernel = build_kernel(
            init_state_fn,
            num_inner_steps,
            num_delete,
            update_strategy=make_chunked_update_strategy(chunk_size),
        )
    
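        def init_fn(position, rng_key=None):
            # rng_key is unused; kept to mirror the upstream init signature.
            if chunk_size is None:
                # Upstream behaviour: initialise all live points in one vmap.
                init_batcher = jax.vmap(init_state_fn)
            else:
                # Chunked path: at most chunk_size particles are vmapped at a time.
                init_batcher = lambda p: jax.lax.map(
                    init_state_fn, p, batch_size=chunk_size
                )
            return ns_init(
                position,
                init_state_fn=init_batcher,
                update_inner_kernel_params_fn=update_inner_kernel_params,
            )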
    
        def step_fn(rng_key, state):
            return kernel(rng_key, state)
    
        return SamplingAlgorithm(init_fn, step_fn)
  2. Switch af.NSS._fit (PyAutoFit/autofit/non_linear/search/nest/nss/search.py): replace the current branch

    if self.chunk_size is not None and self.chunk_size < self.num_delete:
        nss_kwargs["update_strategy"] = make_chunked_update_strategy(self.chunk_size)
    algo = _blackjax.nss(**nss_kwargs)

    with:

    if self.chunk_size is not None and self.chunk_size < max(self.n_live, self.num_delete):
        from autofit.non_linear.search.nest.nss._chunked_nss import (
            build_chunked_nss_algorithm,
        )
        algo = build_chunked_nss_algorithm(
            logprior_fn=prior_logprob,
            loglikelihood_fn=log_likelihood,
            num_inner_steps=self.num_mcmc_steps,
            num_delete=self.num_delete,
            chunk_size=self.chunk_size,
        )
    else:
        algo = _blackjax.nss(**nss_kwargs)

    The max(n_live, num_delete) guard means we use the local replica whenever chunking would actually reduce memory in either of the two vmap sites. If chunk_size is None or already wider than both, fall through to upstream blackjax bit-for-bit.

  3. Unit tests (test_autofit/non_linear/search/nest/nss/test_search.py):

    • Extend the test__chunked_update_strategy_factory style: import build_chunked_nss_algorithm and assert that it returns an object compatible with SamplingAlgorithm (exposes callable .init and .step attributes). No JAX execution: library policy keeps JAX-traced tests in workspace_test.
    • The existing kwarg-plumbing tests already cover af.NSS(chunk_size=...).
  4. Workspace-test follow-up (out of scope here, file as a separate prompt): bit-identical log_Z on a 5D Gaussian between _blackjax.nss(...).init(positions) and build_chunked_nss_algorithm(chunk_size=2).init(positions) on the same seed, with n_live=20.

  5. A100 validation (after merge + HPCPullPyAuto):

    • searches/nss/imaging/pixelization × hst × fp64 — was OOMing as 322605, expect ~46 min (≤ Nautilus's 322603 at 46.5 ms/eval)
    • searches/nss/imaging/delaunay × hst × fp64 — was OOMing as 322606, expect ~45 min (≤ Nautilus's 322601 at 84.8 ms/eval)

Key Files

  • PyAutoFit:autofit/non_linear/search/nest/nss/_chunked_nss.py (new)
  • PyAutoFit:autofit/non_linear/search/nest/nss/search.py (af.NSS._fit switch — 5 line diff)
  • PyAutoFit:test_autofit/non_linear/search/nest/nss/test_search.py (small additive test)

A100 evidence

Original Prompt

af.NSS: chunked algo.init so n_live initial particles don't OOM

Context

Follow-up to PyAutoFit#1303 ("feat(nss): chunked vmap for inversion-heavy
A100 likelihoods"). That PR added a chunk_size kwarg that replaces
blackjax.ns.from_mcmc.update_with_mcmc_take_last's inner
jax.vmap(num_delete) with jax.lax.map(batch_size=chunk_size). The
PyAutoFit unit tests + a 5D Gaussian smoke confirmed bit-identical
log_Z between the unchunked and chunked paths.

A100 validation re-runs of the cells the first PR was supposed to
unblock (jobs 322605 NSS pixelization × HST × fp64, 322606 NSS delaunay
× HST × fp64) OOM at the same allocations as before (28.05 GB
pixelization, 27.67 GB delaunay):

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED:
  Out of memory while trying to allocate 28055330400 bytes.
  in PyAutoArray:autoarray/inversion/mappers/mapper_util.py:315
    mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib_out)

Decisive evidence the chunked-update path isn't the right seam: the
NSS configuration INFO log line (which lives after algo.init and
before the sampling loop) never appears. The OOM fires inside
algo.init(initial_samples), not inside the per-iteration algo.step
that the first PR fixed.

The actual root cause

blackjax.ns.nss.as_top_level_api constructs algo.init as:

def init_fn(position, rng_key=None):
    return init(
        position,
        init_state_fn=jax.vmap(init_state_fn),  # ← hardcoded, no kwarg seam
        update_inner_kernel_params_fn=update_inner_kernel_params_fn,
    )

(blackjax/ns/nss.py:223-230, handley-lab fork at SHA ef45acd2).

The jax.vmap(init_state_fn) is hardcoded inline, so there is no kwarg
seam to inject chunking through blackjax.nss(...). Initial particle
state for all n_live=150 particles is computed in one parallel JAX
call. With the
PyAutoArray:autoarray/inversion/mappers/mapper_util.py:315 scatter
allocating ~184 MB per particle (15,361 image pixels × 1,500 source
pixels × 8 bytes fp64) plus XLA scatter temp buffers, 150 ×
~184 MB ≈ 27.6 GB closely matches the observed OOM allocation (the
ratio 28,055,330,400 / 184,332,000 = 152.2 is conspicuously close to
n_live=150).
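The arithmetic above can be reproduced with a short calculation. This is a back-of-envelope sketch that counts only the dense per-particle matrix, not XLA scatter temporaries. The chunk_size=16 figure is the value the test plan says build_nss sets, and the resulting chunked estimate is an assumption, not a measurement.

```python
# Figures quoted in the issue text.
image_pixels = 15_361
source_pixels = 1_500
bytes_per_fp64 = 8
n_live = 150
observed_oom_bytes = 28_055_330_400

# Dense (image pixels x source pixels) fp64 matrix for one particle.
per_particle_bytes = image_pixels * source_pixels * bytes_per_fp64

# One n_live-wide vmap materialises this matrix for every particle at once.
full_vmap_bytes = n_live * per_particle_bytes

# How many per-particle matrices fit in the failed allocation.
ratio = observed_oom_bytes / per_particle_bytes

# Assumption: chunk_size=16, so only 16 matrices are live at a time.
chunk_size = 16
chunked_bytes = chunk_size * per_particle_bytes

print(f"per particle : {per_particle_bytes / 1e6:.1f} MB")
print(f"n_live wide  : {full_vmap_bytes / 1e9:.2f} GB")
print(f"OOM ratio    : {ratio:.1f}")
print(f"chunk of {chunk_size}  : {chunked_bytes / 1e9:.2f} GB")
```

On this estimate a chunk of 16 particles needs roughly a tenth of the n_live-wide allocation for the dense matrix, before any temporaries are counted.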

The chunked update_strategy the first PR added only fires inside
algo.step, which runs after algo.init returns successfully. For
inversion-heavy lensing cells we never get there.

Desired fix

Replicate blackjax.ns.nss.as_top_level_api locally in PyAutoFit (it's
~30 lines) so we control both:

  1. The chunked update_strategy (already done by PyAutoFit#1303 via
    make_chunked_update_strategy).
  2. The chunked init_fn (this PR).

Sketch (new module autofit/non_linear/search/nest/nss/_chunked_nss.py):

import jax
from functools import partial
from blackjax import SamplingAlgorithm
from blackjax.ns.adaptive import init as ns_init
from blackjax.ns.base import init_state_strategy
from blackjax.ns.nss import (
    build_kernel,
    update_inner_kernel_params,
)
from ._chunked_update import make_chunked_update_strategy


def build_chunked_nss_algorithm(
    *, logprior_fn, loglikelihood_fn,
    num_inner_steps, num_delete, chunk_size,
):
    """Local replica of ``blackjax.nss(as_top_level_api)`` with
    chunked init AND chunked update.

    ``chunk_size`` controls both the inner-vmap inside the per-
    iteration MCMC step (matches PyAutoFit#1303) and the n_live-wide
    vmap inside the algorithm's ``init``. When ``chunk_size`` is
    None or >= max(n_live, num_delete) both paths fall back to
    bit-identical upstream behaviour (plain ``jax.vmap``).
    """
    init_state_fn = partial(
        init_state_strategy,
        logprior_fn=logprior_fn,
        loglikelihood_fn=loglikelihood_fn,
    )

    kernel = build_kernel(
        init_state_fn,
        num_inner_steps,
        num_delete,
        update_strategy=make_chunked_update_strategy(chunk_size),
    )

    def init_fn(position, rng_key=None):
        if chunk_size is None:
            init_batcher = jax.vmap(init_state_fn)
        else:
            init_batcher = lambda p: jax.lax.map(
                init_state_fn, p, batch_size=chunk_size
            )
        return ns_init(
            position,
            init_state_fn=init_batcher,
            update_inner_kernel_params_fn=update_inner_kernel_params,
        )

    def step_fn(rng_key, state):
        return kernel(rng_key, state)

    return SamplingAlgorithm(init_fn, step_fn)

Then in af.NSS._fit, when chunk_size is set and chunking would
actually help (e.g. chunk_size < max(n_live, num_delete)), use
build_chunked_nss_algorithm instead of _blackjax.nss(...).
chunk_size=None keeps using blackjax.nss(...) unchanged for
bit-identical fallback.
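The branch condition described above can be sketched as a small predicate. The helper names and the num_delete=50 value are hypothetical, chosen only to show where the widened guard differs from the PyAutoFit#1303 guard (chunk_size < num_delete).

```python
from typing import Optional


def old_guard(chunk_size: Optional[int], num_delete: int) -> bool:
    # Guard from PyAutoFit#1303: only considers the per-iteration step width.
    return chunk_size is not None and chunk_size < num_delete


def use_chunked_builder(
    chunk_size: Optional[int], n_live: int, num_delete: int
) -> bool:
    # Hypothetical helper mirroring the proposed guard: chunk whenever either
    # the n_live-wide init or the num_delete-wide step would be narrowed.
    if chunk_size is None:
        return False
    return chunk_size < max(n_live, num_delete)


# Illustrative values: n_live=150 is from the issue, num_delete=50 is assumed.
n_live, num_delete = 150, 50

for chunk_size in (None, 16, 64, 150):
    print(
        chunk_size,
        old_guard(chunk_size, num_delete),
        use_chunked_builder(chunk_size, n_live, num_delete),
    )
```

With these assumed values, chunk_size=64 is the case the old guard misses: it is wider than num_delete but still narrower than n_live, so only the init path benefits.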

Test plan

  • Unit test (test_autofit, no JAX): factory builds a SamplingAlgorithm
    with the expected (init, step) shape; kwarg plumbing on af.NSS
    threads chunk_size through.
  • Workspace-test (JAX-traced): bit-identical log_Z on a 5D Gaussian
    between _blackjax.nss(...).init (full vmap) and our
    build_chunked_nss_algorithm(chunk_size=2).init on the same seed,
    with n_live=20. (Same shape as the PyAutoFit#1303 smoke; that
    smoke only exercised the step seam, not the init seam.)
  • A100 end-to-end: resubmit autolens_profiling's
    searches/nss/imaging/{pixelization,delaunay} × hst × fp64;
    chunk_size=16 set automatically by
    autolens_profiling/searches/_samplers.py:build_nss. Confirm
    completion (was OOMing as 322605/606 after PyAutoFit#1303 landed,
    same allocations as 322602/604 before it).
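The no-JAX shape check in the first bullet could look roughly like the sketch below. StubSamplingAlgorithm is a hypothetical stand-in used so the example runs without blackjax; the real test would call build_chunked_nss_algorithm and apply the same check to its return value.

```python
from typing import Callable, NamedTuple


class StubSamplingAlgorithm(NamedTuple):
    """Hypothetical stand-in for an (init, step) sampling-algorithm pair."""

    init: Callable
    step: Callable


def has_sampling_algorithm_shape(algo) -> bool:
    # The check the unit test would make: both attributes exist and are callable.
    return callable(getattr(algo, "init", None)) and callable(
        getattr(algo, "step", None)
    )


# Placeholder callables with the same argument names as the sketch's
# init_fn(position, rng_key=None) and step_fn(rng_key, state).
stub = StubSamplingAlgorithm(
    init=lambda position, rng_key=None: position,
    step=lambda rng_key, state: (state, None),
)

print(has_sampling_algorithm_shape(stub))
print(has_sampling_algorithm_shape(object()))
```

Checking only for callable attributes keeps the test free of JAX tracing, in line with the library policy noted in the plan.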

Affected callers / interaction surface

  • af.NSS._fit switches from _blackjax.nss(...) to
    build_chunked_nss_algorithm(...) when chunk_size is set and
    smaller than the wider of n_live / num_delete.
  • autolens_profiling — no change needed. build_nss already
    sets chunk_size=vmap_batch_for_cell(...); the PyAutoFit-side
    change is transparent.
  • handley-lab/blackjax — still no patch needed for this fix.
    Long-term cleanup: file an upstream PR adding init_batcher (or
    similar) as a kwarg on blackjax.nss(as_top_level_api) so our
    local replica can shrink to a thin forwarder. Lower priority since
    the PyAutoFit shim works.

Why this matters

PyAutoFit#1303 was the right partial fix — per-iteration vmap chunking
is necessary, just not sufficient. Without this follow-up, NSS still
can't profile or run on the production lensing cells (SLaM
source_pix[1/2] Delaunay / pixelization phases) at A100 80 GB,
which were exactly the cells the original profile sweep was supposed
to compare against Nautilus.

The same A100 evidence as PyAutoFit#1301 applies; the next round
(322605 NSS pix, 322606 NSS delaunay) confirms the bug is in the
init path. Nautilus baselines for the comparison still apply:
pixelization 46.5 ms/eval / 46 min (322603), delaunay 84.8 ms/eval /
45 min (322601), NSS MGE 1.6 ms/eval / 11 min (322590).

Out of scope

  • Replacing slice-MCMC with HMC / NUTS for better mixing — separate
    upstream concern.
  • Multi-GPU sharding via jax.shard_map — single-GPU chunked init is
    the cheapest fix for the immediate gap.

Cross-references

  • PyAutoFit#1303 — first chunked-vmap PR, fixed the per-iteration path
  • PyAutoFit#1301 — original issue with chunked-vmap framing
  • autolens_profiling#43 — workspace consumer wiring of chunk_size
  • PyAutoPrompt/autoarray/delaunay_interpolator_pure_callback_vmap_memory.md
    — separate efficiency follow-up (already shown not to be the OOM cause)
  • blackjax/ns/nss.py:223-230 (the hardcoded jax.vmap(init_state_fn))
  • A100 evidence: 322605 (NSS pix), 322606 (NSS delaunay) — both OOM
    at the same byte counts as 322604 / 322602 from before
    PyAutoFit#1303
