diff --git a/searches/_runner.py b/searches/_runner.py index 8eee5ad..2981a57 100644 --- a/searches/_runner.py +++ b/searches/_runner.py @@ -201,7 +201,8 @@ def _sampler_config_dict( return { "n_live": n_live, "num_mcmc_steps": int(_NSS_DEFAULTS["num_mcmc_steps"]), - "num_delete": min(int(_NSS_DEFAULTS["num_delete"]), batch), + "num_delete": int(_NSS_DEFAULTS["num_delete"]), + "chunk_size": batch, "termination": float(_NSS_DEFAULTS["termination"]), "seed": int(_NSS_DEFAULTS["seed"]), "jax_native": True, diff --git a/searches/_samplers.py b/searches/_samplers.py index 82a1e40..09ac2e0 100644 --- a/searches/_samplers.py +++ b/searches/_samplers.py @@ -171,21 +171,25 @@ def build_nss( "for the NumPy-front profile." ) n_live = n_live_for(dataset_class, model_type) - # NSS's ``num_delete`` plays the same role as Nautilus ``n_batch``: it - # controls how many likelihoods fire in parallel per outer iteration. - # Cap it at the per-cell vmap budget so heavy cells (delaunay, inversion- - # based) don't OOM the A100. Floor at the default so small cells still - # benefit from sane batching. - num_delete = min( - int(_NSS_DEFAULTS["num_delete"]), - vmap_batch_for_cell(dataset_class, model_type, instrument), - ) + # Memory-budget plumbing: + # - ``num_delete`` stays at the sampler's preferred default (50 particles + # per outer iteration) so convergence isn't compromised. + # - ``chunk_size = vmap_batch_for_cell(...)`` caps the inner-vmap fan-out + # per the A100-probed budget. PyAutoFit#1303 swaps blackjax's internal + # ``jax.vmap(num_delete)`` for ``jax.lax.map(batch_size=chunk_size)`` + # when ``chunk_size < num_delete``, so peak GPU memory becomes + # ``chunk_size × per_particle_state`` instead of ``num_delete × ...``. + # - ``chunk_size = None`` would also work on cells where the probe value + # is >= num_delete (e.g. point_source fallback), but passing the probe + # value explicitly keeps the JSON record honest about what was capped. + chunk_size = vmap_batch_for_cell(dataset_class, model_type, instrument) return af.NSS( name=config_name, path_prefix=f"searches/{sampler}/{dataset_class}/{model_type}/{instrument}", n_live=n_live, num_mcmc_steps=int(_NSS_DEFAULTS["num_mcmc_steps"]), - num_delete=num_delete, + num_delete=int(_NSS_DEFAULTS["num_delete"]), + chunk_size=chunk_size, termination=float(_NSS_DEFAULTS["termination"]), seed=int(_NSS_DEFAULTS["seed"]), )