Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion searches/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def _sampler_config_dict(
return {
"n_live": n_live,
"num_mcmc_steps": int(_NSS_DEFAULTS["num_mcmc_steps"]),
"num_delete": min(int(_NSS_DEFAULTS["num_delete"]), batch),
"num_delete": int(_NSS_DEFAULTS["num_delete"]),
"chunk_size": batch,
"termination": float(_NSS_DEFAULTS["termination"]),
"seed": int(_NSS_DEFAULTS["seed"]),
"jax_native": True,
Expand Down
24 changes: 14 additions & 10 deletions searches/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,25 @@ def build_nss(
"for the NumPy-front profile."
)
n_live = n_live_for(dataset_class, model_type)
# NSS's ``num_delete`` plays the same role as Nautilus ``n_batch``: it
# controls how many likelihoods fire in parallel per outer iteration.
# Cap it at the per-cell vmap budget so heavy cells (delaunay, inversion-
# based) don't OOM the A100. The sampler default is the upper bound, so
# cells with a larger budget still run at the default batch size.
num_delete = min(
int(_NSS_DEFAULTS["num_delete"]),
vmap_batch_for_cell(dataset_class, model_type, instrument),
)
# Memory-budget plumbing:
# - ``num_delete`` stays at the sampler's preferred default (50 particles
# per outer iteration) so convergence isn't compromised.
# - ``chunk_size = vmap_batch_for_cell(...)`` caps the inner-vmap fan-out
# per the A100-probed budget. PyAutoFit#1303 swaps blackjax's internal
# ``jax.vmap(num_delete)`` for ``jax.lax.map(batch_size=chunk_size)``
# when ``chunk_size < num_delete``, so peak GPU memory becomes
# ``chunk_size × per_particle_state`` instead of ``num_delete × ...``.
# - ``chunk_size = None`` would also work on cells where the probe value
# is >= num_delete (e.g. point_source fallback), but passing the probe
# value explicitly keeps the JSON record honest about what was capped.
chunk_size = vmap_batch_for_cell(dataset_class, model_type, instrument)
return af.NSS(
name=config_name,
path_prefix=f"searches/{sampler}/{dataset_class}/{model_type}/{instrument}",
n_live=n_live,
num_mcmc_steps=int(_NSS_DEFAULTS["num_mcmc_steps"]),
num_delete=num_delete,
num_delete=int(_NSS_DEFAULTS["num_delete"]),
chunk_size=chunk_size,
termination=float(_NSS_DEFAULTS["termination"]),
seed=int(_NSS_DEFAULTS["seed"]),
)
Expand Down
Loading