Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions autofit/non_linear/search/nest/nss/_chunked_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Chunked replacement for ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``.

Upstream blackjax fans out ``num_delete`` particles through ``jax.vmap``
with no chunking:

sample_keys = jax.random.split(sample_key, num_delete)
return jax.vmap(mcmc_kernel)(sample_keys, start_state)

On inversion-heavy likelihoods (e.g. PyAutoLens pixelization / Delaunay
source models) the per-particle MCMC state plus scatter temp buffers
exceeds A100 80 GB even at ``num_delete=16``. See PyAutoFit#1301 for the
full diagnosis and per-cell evidence from ``autolens_profiling``.

``make_chunked_update_strategy`` takes a ``chunk_size`` argument and returns
``chunked_update_with_mcmc_take_last``, which swaps the vmap for
``jax.lax.map(..., batch_size=chunk_size)`` when ``chunk_size < num_delete``:
particles within a chunk are still vmapped in parallel, while the chunks
themselves are processed sequentially. Peak memory becomes roughly
``chunk_size × per_particle_state`` instead of
``num_delete × per_particle_state``.

When ``chunk_size`` is None or ``>= num_delete`` the function is
bit-identical to upstream.

``blackjax.nss(...)`` already exposes ``update_strategy`` as a kwarg
(see ``blackjax/ns/nss.py:157``), so ``af.NSS._fit`` only needs to pass
this builder to opt in:

algo = _blackjax.nss(
...,
update_strategy=make_chunked_update_strategy(chunk_size),
)
"""

from __future__ import annotations

from functools import partial
from typing import Callable, Optional


def make_chunked_update_strategy(chunk_size: Optional[int]) -> Callable:
    """Return an ``update_strategy`` callable for ``blackjax.nss(...)``.

    The returned builder has the same three-argument signature as
    ``blackjax.ns.from_mcmc.update_with_mcmc_take_last`` so it can be passed
    through the ``update_strategy=`` kwarg unmodified.

    Parameters
    ----------
    chunk_size
        Number of particles to vmap-batch per chunk. When None or
        ``>= num_delete`` the chunked path is skipped and the function
        falls through to a plain ``jax.vmap`` (the same call upstream
        makes).

    Returns
    -------
    Callable
        ``chunked_update_with_mcmc_take_last(constrained_mcmc_step_fn,
        num_mcmc_steps, num_delete)``, which itself returns the
        ``update_function(rng_key, state, loglikelihood_0, **step_parameters)``
        that runs the MCMC replacement step.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not None and is smaller than 1. Such a value
        would otherwise be handed to ``jax.lax.map(batch_size=...)`` and
        only fail (or misbehave) deep inside tracing.
    """
    # Fail fast with a readable message: 0 / negative values satisfy
    # ``chunk_size < num_delete`` and would select the chunked path below.
    if chunk_size is not None and chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer or None, got {chunk_size!r}."
        )

    def chunked_update_with_mcmc_take_last(
        constrained_mcmc_step_fn,
        num_mcmc_steps,
        num_delete,
    ):
        """Drop-in for ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``.

        Identical to upstream except the inner
        ``jax.vmap(mcmc_kernel)(sample_keys, start_state)`` is replaced
        with ``jax.lax.map(..., batch_size=chunk_size)`` when
        ``chunk_size`` is set and smaller than ``num_delete``.

        Parameters
        ----------
        constrained_mcmc_step_fn
            Single MCMC step, called as ``fn(rng_key, state,
            loglikelihood_0=..., **step_parameters)`` and returning
            ``(new_state, info)``.
        num_mcmc_steps
            Number of MCMC steps scanned per replaced particle.
        num_delete
            Number of start particles drawn (and chains run) per call.
        """
        # jax is imported lazily so importing this module does not require it.
        import jax
        import jax.numpy as jnp

        # Both values are plain Python objects fixed when the strategy is
        # built, so the vmap-vs-chunked choice is made once here rather than
        # on every call of ``update_function``.
        run_unchunked = chunk_size is None or chunk_size >= num_delete

        def update_function(rng_key, state, loglikelihood_0, **step_parameters):
            """Run ``num_delete`` constrained MCMC chains and return their last states.

            Returns the ``(final_state, infos)`` pair produced by mapping
            ``mcmc_kernel`` over ``num_delete`` start particles.
            """
            choice_key, sample_key = jax.random.split(rng_key)
            particles = state.particles

            # Select start particles from survivors (verbatim from upstream):
            # weight 1 for particles above the likelihood threshold, 0
            # otherwise; if nothing survives, fall back to uniform weights.
            weights = (particles.loglikelihood > loglikelihood_0).astype(jnp.float32)
            weights = jnp.where(weights.sum() > 0.0, weights, jnp.ones_like(weights))
            start_idx = jax.random.choice(
                choice_key,
                len(weights),
                shape=(num_delete,),
                p=weights / weights.sum(),
                replace=True,
            )
            start_state = jax.tree.map(lambda x: x[start_idx], particles)

            # Bind the arguments shared by every chain so the scan body only
            # has to supply ``(rng_key, state)``.
            shared_mcmc_step_fn = partial(
                constrained_mcmc_step_fn,
                loglikelihood_0=loglikelihood_0,
                **step_parameters,
            )

            def mcmc_kernel(rng_key, state):
                """Scan ``num_mcmc_steps`` MCMC steps from ``state``; keep the last."""
                keys = jax.random.split(rng_key, num_mcmc_steps)

                def body_fn(state, rng_key):
                    new_state, info = shared_mcmc_step_fn(rng_key, state)
                    return new_state, info

                final_state, infos = jax.lax.scan(body_fn, state, keys)
                return final_state, infos

            sample_keys = jax.random.split(sample_key, num_delete)

            # Fall through to the upstream behaviour when the user hasn't
            # asked for chunking, or when the requested chunk already covers
            # every particle.
            if run_unchunked:
                return jax.vmap(mcmc_kernel)(sample_keys, start_state)

            # Chunked path: jax.lax.map(batch_size=k) vmaps within each
            # chunk-of-k particles and loops across chunks.
            return jax.lax.map(
                lambda xs: mcmc_kernel(xs[0], xs[1]),
                (sample_keys, start_state),
                batch_size=chunk_size,
            )

        return update_function

    return chunked_update_with_mcmc_take_last
36 changes: 33 additions & 3 deletions autofit/non_linear/search/nest/nss/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
n_live: int = 200,
num_mcmc_steps: int = 5,
num_delete: int = 50,
chunk_size: Optional[int] = None,
termination: float = -3.0,
checkpoint_interval: int = 100,
iterations_per_quick_update: Optional[int] = None,
Expand Down Expand Up @@ -195,6 +196,17 @@ def __init__(
Number of particles killed per outer iteration. Larger
``num_delete`` reduces JIT overhead per iteration at the cost of
slightly worse posterior coverage.
chunk_size
Optional GPU-memory knob. When set and ``< num_delete``, the
inner MCMC step vmap (which fans out ``num_delete`` particles
in parallel inside ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``)
is replaced with ``jax.lax.map(..., batch_size=chunk_size)``.
Peak GPU memory becomes ``chunk_size × per_particle_state``
instead of ``num_delete × per_particle_state`` — required to
run NSS on inversion-heavy likelihoods (PyAutoLens pixelization
/ Delaunay) at A100 80 GB scale. Default ``None`` preserves the
upstream un-chunked behaviour (and is the right choice on CPU
or whenever ``num_delete`` already fits the device).
termination
Convergence criterion. The fit stops when
``logZ_live - logZ < termination``. Default ``-3.0`` corresponds
Expand Down Expand Up @@ -262,6 +274,7 @@ def __init__(
self.n_live = n_live
self.num_mcmc_steps = num_mcmc_steps
self.num_delete = num_delete
self.chunk_size = chunk_size
self.termination = termination
self.checkpoint_interval = checkpoint_interval
self.seed = seed
Expand Down Expand Up @@ -342,12 +355,28 @@ def prior_logprob(params):
]
)

algo = _blackjax.nss(
nss_kwargs = dict(
logprior_fn=prior_logprob,
loglikelihood_fn=log_likelihood,
num_delete=self.num_delete,
num_inner_steps=self.num_mcmc_steps,
)
# When ``chunk_size`` is set and below ``num_delete``, swap blackjax's
# default ``update_with_mcmc_take_last`` for a chunked variant whose
# inner vmap becomes ``jax.lax.map(batch_size=chunk_size)`` — see
# ``_chunked_update.py`` for the rationale and PyAutoFit#1301 for
# per-A100-cell evidence. ``chunk_size=None`` / ``>= num_delete`` are
# no-ops (the chunked builder still falls back to ``jax.vmap``).
if self.chunk_size is not None and self.chunk_size < self.num_delete:
from autofit.non_linear.search.nest.nss._chunked_update import (
make_chunked_update_strategy,
)

nss_kwargs["update_strategy"] = make_chunked_update_strategy(
self.chunk_size
)

algo = _blackjax.nss(**nss_kwargs)

@jax.jit
def one_step(carry, _):
Expand All @@ -370,11 +399,12 @@ def one_step(carry, _):
iteration = 0
self.logger.info(
"NSS configuration: n_live=%d, num_mcmc_steps=%d, num_delete=%d, "
"termination=%s, ndim=%d, checkpoint_interval=%d. JIT compile on "
"first iteration may take 25-30 s.",
"chunk_size=%s, termination=%s, ndim=%d, checkpoint_interval=%d. "
"JIT compile on first iteration may take 25-30 s.",
self.n_live,
self.num_mcmc_steps,
self.num_delete,
self.chunk_size,
self.termination,
ndim,
self.checkpoint_interval,
Expand Down
28 changes: 28 additions & 0 deletions test_autofit/non_linear/search/nest/nss/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,52 @@ def test__explicit_params():
n_live=500,
num_mcmc_steps=10,
num_delete=20,
chunk_size=4,
termination=-2.0,
seed=7,
)

assert search.n_live == 500
assert search.num_mcmc_steps == 10
assert search.num_delete == 20
assert search.chunk_size == 4
assert search.termination == -2.0
assert search.seed == 7

default = af.NSS()
assert default.n_live == 200
assert default.num_mcmc_steps == 5
assert default.num_delete == 50
assert default.chunk_size is None
assert default.termination == -3.0
assert default.seed == 42


def test__chunked_update_strategy_factory():
    """The builder from ``make_chunked_update_strategy`` keeps blackjax's
    ``update_with_mcmc_take_last`` parameter names whether or not a
    ``chunk_size`` is supplied, so ``af.NSS._fit`` can hand it to
    ``blackjax.nss(update_strategy=...)`` without further branching.
    """
    import inspect

    from autofit.non_linear.search.nest.nss._chunked_update import (
        make_chunked_update_strategy,
    )

    # Upstream three-arg signature the strategy must expose.
    expected_params = [
        "constrained_mcmc_step_fn",
        "num_mcmc_steps",
        "num_delete",
    ]

    # Build every strategy first, then check each one: un-chunked (None)
    # and chunked (4).
    strategies = [make_chunked_update_strategy(size) for size in (None, 4)]
    for built in strategies:
        assert list(inspect.signature(built).parameters) == expected_params


def test__identifier_fields():
search = af.NSS()
for field in ("n_live", "num_mcmc_steps", "num_delete", "termination", "seed"):
Expand Down
Loading