Overview
Wire up checkpointing (Phase 2) and on-the-fly visualization (Phase 3) for af.NSS. Both phases share the same architectural hook — the outer loop in af.NSS._fit — so they ship together in one PyAutoFit PR. Closes Phases 2-3 of the nss_first_class_sampler roadmap.
Critical finding during the Phase 1 follow-up audit: the roadmap's framing was wrong. nss.ns.run_nested_sampling is not a one-shot JIT'd jax.lax.while_loop requiring an upstream yallup/nss PR. The JIT boundary is one_step (one outer iteration); the outer loop is plain Python. This means both checkpointing and visualization can be implemented entirely inside af.NSS._fit — no upstream dependency.
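A toy sketch of why this matters, under stated assumptions. When a JIT'd step function is driven by a plain Python loop, control returns to the host after every iteration. The carry (state plus PRNG key) can therefore be saved, inspected or plotted between calls, and continuing from a saved carry reproduces an uninterrupted run.

The update rule here (overwrite the lowest value with a fresh uniform draw) is only a stand-in for NSS's slice-sampling replacement. A fixed iteration count replaces the logZ termination test. `run` and `hook` are illustrative names, not nss or PyAutoFit API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def one_step(carry, xs):
    """Toy outer iteration: compiled once, then called from a Python loop."""
    state, key = carry
    key, subkey = jax.random.split(key, 2)
    worst = jnp.argmin(state)
    dead_point = state[worst]
    # Stand-in for NSS's replacement step: overwrite the worst point.
    state = state.at[worst].set(jax.random.uniform(subkey))
    return (state, key), dead_point


def run(state, key, n_iter, hook=None):
    """Plain-Python outer loop; a fixed iteration count replaces the logZ test."""
    dead = []
    iteration = 0
    while iteration < n_iter:
        (state, key), dead_point = one_step((state, key), None)
        dead.append(dead_point)
        iteration += 1
        if hook is not None:
            # Host-side hook: this is where a checkpoint or a plot could go.
            hook(iteration, state, key, dead)
    return state, key, dead
```

You can split five iterations into three plus two by passing the returned state and key back in. The result is bit-identical to a single five-iteration call, and that property is what makes the planned resume possible.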
Plan
- Replace the single nss.ns.run_nested_sampling(...) call inside NSS._fit with an inlined equivalent: build the blackjax.nss algorithm and a JIT'd one_step closure, then run the outer loop locally.
- Add a checkpoint_interval: int = 100 kwarg. Every N outer iterations, pickle (state, dead, rng_key, iteration) to paths.search_internal_path / "nss_checkpoint.pkl", using the standard JAX pytree → NumPy round-trip for serialisation.
- On _fit entry, detect an existing checkpoint and resume from it (state + dead list + RNG); otherwise initialise fresh. The Phase 1 "resume not yet supported" warning goes away.
- Wire the existing iterations_per_quick_update kwarg (accepted but a no-op in Phase 1) to call analysis.visualize(paths=..., instance=..., during_analysis=True) on the current best live point between outer iterations.
- On successful completion, delete the checkpoint file, mirroring Nautilus's output_search_internal pattern.
- Add unit tests for the _save_checkpoint / _load_checkpoint pytree round-trip, checkpoint_interval kwarg acceptance, and resume detection (mocked to avoid running real nss).
- Add an end-to-end resume integration smoke under autolens_workspace_developer/searches_minimal/: run for N iterations, simulate an interrupt, restart, and confirm the run continues to convergence with an identical final state.
Detailed implementation plan
Affected Repositories
- PyAutoFit (primary — library)
- autolens_workspace_developer (secondary — resume smoke + viz smoke scripts)
Work Classification
Library (with dev-workspace smoke ridealong on the same branch)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoFit | main | clean |
| ./autolens_workspace_developer | main | dirty (pre-existing local work; not ours, untouched) |
worktree_check_conflict (exit=0): no active task claims either repo.
Suggested branch: feature/nss-checkpointing-and-visualization
Worktree root: ~/Code/PyAutoLabs-wt/nss-checkpointing-and-visualization/ (created by /start_library)
Implementation Steps
- Inline the outer loop in NSS._fit (PyAutoFit/autofit/non_linear/search/nest/nss/search.py). Replace the existing run_nested_sampling(...) call with:
```python
import time

import blackjax
import jax
from nss.ns import finalise, log_weights as _nss_log_weights, Results, safe_ess

algo = blackjax.nss(
    logprior_fn=prior_logprob,
    loglikelihood_fn=log_likelihood,
    num_delete=self.num_delete,
    num_inner_steps=self.num_mcmc_steps,
)

@jax.jit
def one_step(carry, xs):
    # One outer iteration: algo.step processes num_delete deaths in one batch.
    state, k = carry
    k, subk = jax.random.split(k, 2)
    state, dead_point = algo.step(subk, state)
    return (state, k), dead_point

checkpoint_path = self._nss_checkpoint_path
if checkpoint_path is not None and checkpoint_path.exists():
    # Resume: restore the loop carry (state + RNG key) and the dead list.
    state, dead, run_key, iteration = _load_checkpoint(checkpoint_path)
    self.logger.info("Resuming NSS from iteration %d", iteration)
else:
    # Fresh start: run_key is assumed to be created earlier in _fit.
    state = algo.init(initial_samples)
    dead = []
    iteration = 0

t_start = time.time()  # on resume this times only the resumed segment
while not (state.integrator.logZ_live - state.integrator.logZ < self.termination):
    (state, run_key), dead_info = one_step((state, run_key), None)
    dead.append(dead_info)
    iteration += 1

    # Phase 2: checkpoint every checkpoint_interval outer iterations
    if checkpoint_path is not None and iteration % self.checkpoint_interval == 0:
        _save_checkpoint(checkpoint_path, state, dead, run_key, iteration)

    # Phase 3: quick-update visualization
    if (self.iterations_per_quick_update is not None
            and iteration % self.iterations_per_quick_update == 0):
        self._fire_quick_update(state, model, analysis)

wall_time = time.time() - t_start
final_state = finalise(state, dead)
# ... existing _NSSInternal repackaging
```
- Add _save_checkpoint and _load_checkpoint module-level helpers: a pickle-based round-trip via jax.tree_util.tree_map(np.asarray, ...) and its inverse. Make them robust to interrupted writes via an atomic rename (pickle.dump to *.tmp, then os.replace onto the final path).
- Add a _fire_quick_update(state, model, analysis) helper. It extracts the highest-log-likelihood live particle from state.particles, maps it to a ModelInstance via model.instance_from_vector, and calls analysis.visualize(paths=self.paths, instance=instance, during_analysis=True). Wrap it in try/except to log and continue on visualization errors (don't kill a long fit because a plot misfired).
- Add a checkpoint_interval kwarg to NSS.__init__ (default 100). Update the docstring's "Stubbed / out of scope" section and remove the Phase 1 "resume not yet supported" and "quick-update visualization not yet wired" warnings.
- Delete the checkpoint on success: at the end of _fit, after _NSSInternal is built, if checkpoint_path is not None and checkpoint_path.exists(): checkpoint_path.unlink(). This mirrors Nautilus's output_search_internal post-success cleanup.
- Unit tests: PyAutoFit/test_autofit/non_linear/search/nest/nss/test_search.py (extend) and a new test_checkpoint.py:
  - _save_checkpoint / _load_checkpoint round-trip on a synthetic state pytree
  - NSS.__init__(checkpoint_interval=...) is accepted; identifier_fields are unchanged
  - Resume detection: monkeypatch Path.exists and the loader, and confirm _fit enters the resume branch (without running nss)
  - Atomic-write semantics: an interrupt mid-_save_checkpoint leaves a .tmp file, not a partial .pkl
- Integration smoke: autolens_workspace_developer/searches_minimal/nss_checkpoint_resume.py:
  - Run af.NSS with checkpoint_interval=5, num_delete=10, n_live=40 (2-parameter Gaussian smoke model)
  - Mid-run, after a checkpoint has been saved, kill the run via sys.exit inside a quick-update callback
  - Restart with the same paths; assert that the resume happens (log line) and that the final state matches a single-shot run
- Quick-update smoke: extend nss_first_class_gaussian.py (or create a sibling). Set iterations_per_quick_update=3, run, and assert that paths.image_path contains PNGs written before final convergence.
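One way the _fire_quick_update helper could look, written as a free function so it can be tested in isolation. It assumes, without having checked the blackjax state layout, that state.particles exposes position as an (n_live, n_dim) array of physical parameter values and loglikelihood as an (n_live,) array. If positions are stored in the unit cube, model.instance_from_unit_vector would replace model.instance_from_vector. The name fire_quick_update and the explicit paths argument are hypothetical.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def fire_quick_update(state, model, analysis, paths):
    """Visualize the best live point; log and carry on if plotting fails.

    Assumed (not verified) layout: state.particles.position is (n_live, n_dim)
    in physical parameter space and state.particles.loglikelihood is (n_live,).
    """
    try:
        loglikelihood = np.asarray(state.particles.loglikelihood)
        positions = np.asarray(state.particles.position)
        best = int(np.argmax(loglikelihood))
        instance = model.instance_from_vector(vector=positions[best].tolist())
        analysis.visualize(paths=paths, instance=instance, during_analysis=True)
        return best
    except Exception:
        # Exception, not BaseException: Ctrl-C and sys.exit still stop the fit.
        logger.exception("NSS quick-update visualization failed; continuing the fit")
        return None
```

Catching Exception rather than BaseException matters here. KeyboardInterrupt and SystemExit still propagate, so a user's Ctrl-C, or the sys.exit used in the planned resume smoke, still stops the fit.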
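The new checkpoint_interval kwarg feeds straight into iteration % self.checkpoint_interval == 0 in the loop, so a value of 0 would raise ZeroDivisionError on the first outer iteration. Below is a small validation sketch for NSS.__init__. validate_interval and should_fire are hypothetical helper names, and treating None as "disabled" mirrors how the loop already handles iterations_per_quick_update.

```python
def validate_interval(name, value, allow_none=False):
    """Reject cadences that would break `iteration % interval == 0`."""
    if value is None and allow_none:
        return None
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if value < 1:
        # iteration % 0 raises ZeroDivisionError; negative cadences are meaningless.
        raise ValueError(f"{name} must be >= 1, got {value}")
    return value


def should_fire(iteration, interval):
    """True on every interval-th outer iteration; None disables the hook."""
    return interval is not None and iteration % interval == 0
```

In __init__, `self.checkpoint_interval = validate_interval("checkpoint_interval", checkpoint_interval)` would fail fast at construction rather than at the first outer iteration.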
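A sketch of the post-success cleanup step, assuming Python 3.8+ for Path.unlink(missing_ok=True). Compared with the exists()-then-unlink() form in the plan, it leaves no window between the check and the delete. It also removes an orphaned .tmp left by an earlier interrupted save. cleanup_checkpoint is a hypothetical name.

```python
from pathlib import Path


def cleanup_checkpoint(checkpoint_path):
    """Delete the checkpoint, and any orphaned .tmp, after a successful fit."""
    if checkpoint_path is None:
        return
    checkpoint_path = Path(checkpoint_path)
    # missing_ok=True (Python 3.8+) removes the exists()/unlink() race.
    checkpoint_path.unlink(missing_ok=True)
    # A .tmp only survives if an earlier save was interrupted mid-write.
    checkpoint_path.with_name(checkpoint_path.name + ".tmp").unlink(missing_ok=True)
```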
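The interrupt-and-resume smoke can be rehearsed on a toy loop before wiring it to af.NSS. In this sketch (toy_fit and its arguments are hypothetical), a per-iteration callback raises SystemExit after a checkpoint has been written, and a second call resumes from that checkpoint. Because the RNG state travels inside the checkpoint, the resumed result must match an uninterrupted run exactly. Two simplifications apply: np.random.Generator stands in for the JAX PRNG key, and checkpoint writes are non-atomic for brevity.

```python
import pickle
from pathlib import Path

import numpy as np


def toy_fit(checkpoint, n_iter=50, checkpoint_interval=5, on_iteration=None):
    """Toy resumable loop; the carry is (values, rng, iteration)."""
    checkpoint = Path(checkpoint)
    if checkpoint.exists():
        with open(checkpoint, "rb") as f:
            values, rng, iteration = pickle.load(f)  # resume
    else:
        values, rng, iteration = np.zeros(4), np.random.default_rng(0), 0
    while iteration < n_iter:
        worst = int(np.argmin(values))
        values[worst] = rng.uniform()
        iteration += 1
        if iteration % checkpoint_interval == 0:
            with open(checkpoint, "wb") as f:
                # The pickled Generator carries its bit-generator state.
                pickle.dump((values, rng, iteration), f)
        if on_iteration is not None:
            on_iteration(iteration)  # stand-in for the quick-update callback
    checkpoint.unlink(missing_ok=True)  # clean up only after success
    return values
```

Killing the run at iteration 23 leaves the iteration-20 checkpoint on disk, so iterations 21 to 23 are redone after the restart. That rework is the expected cost of checkpoint_interval=5.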
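Checking file modification times after the fit may not prove that quick updates happened. If, as is typical, the final visualization rewrites the same filenames, every PNG ends up with a post-convergence timestamp.

An alternative is to wrap the analysis and snapshot the image directory after each during_analysis=True call. The sketch below assumes af.NSS only calls visualize and reads public attributes on the analysis it is given. RecordingAnalysis is a hypothetical name; if af.NSS type-checks the analysis, subclassing the real Analysis class would be needed instead.

```python
from pathlib import Path


class RecordingAnalysis:
    """Wrap an analysis; snapshot image_dir after each during-analysis visualize."""

    def __init__(self, analysis, image_dir):
        self._analysis = analysis
        self._image_dir = Path(image_dir)
        self.snapshots = []

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Refusing private names
        # avoids infinite recursion if the wrapper is copied or unpickled.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._analysis, name)

    def visualize(self, paths, instance, during_analysis):
        self._analysis.visualize(
            paths=paths, instance=instance, during_analysis=during_analysis
        )
        if during_analysis:
            pngs = sorted(p.name for p in self._image_dir.rglob("*.png"))
            self.snapshots.append(pngs)
```

The smoke script could then assert that the first entry in snapshots already lists PNGs, which shows that plots existed before convergence.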
Key Files
PyAutoFit/autofit/non_linear/search/nest/nss/search.py — main _fit rewrite + checkpoint helpers + viz hook
PyAutoFit/test_autofit/non_linear/search/nest/nss/test_search.py — extended unit tests
PyAutoFit/test_autofit/non_linear/search/nest/nss/test_checkpoint.py — new, focused on serialisation round-trip
autolens_workspace_developer/searches_minimal/nss_checkpoint_resume.py — new integration smoke
autolens_workspace_developer/searches_minimal/nss_first_class_gaussian.py — extend with quick-update assertions
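For the new test_checkpoint.py, the atomic-write case can be tested without real signals. An object whose __reduce__ raises stands in for a kill arriving mid-pickle.dump. The sketch below uses its own stand-in atomic_pickle rather than the real _save_checkpoint, and every name in it is hypothetical. The test function takes pytest's tmp_path fixture.

```python
import os
import pickle
from pathlib import Path


class SimulatedInterrupt(Exception):
    """Stands in for a kill arriving part-way through pickle.dump."""


class ExplodesOnPickle:
    def __reduce__(self):
        raise SimulatedInterrupt("interrupted mid-write")


def atomic_pickle(obj, path):
    """Stand-in for _save_checkpoint: write to .tmp, then os.replace."""
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def test_interrupted_save_keeps_previous_checkpoint(tmp_path):
    final = tmp_path / "nss_checkpoint.pkl"
    atomic_pickle({"iteration": 100}, final)
    try:
        atomic_pickle({"iteration": 200, "boom": ExplodesOnPickle()}, final)
    except SimulatedInterrupt:
        pass
    else:
        raise AssertionError("save should have been interrupted")
    # Only the .tmp holds the torn write; the previous checkpoint is intact.
    assert (tmp_path / "nss_checkpoint.pkl.tmp").exists()
    with open(final, "rb") as f:
        assert pickle.load(f) == {"iteration": 100}
```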
Out of scope
- JIT persistent cache (separate follow-up — each cold + resumed fit pays 25-30 s while_loop compile)
- Install simplification: Phase 4 (autofit/nss_install_simplification.md)
- Workspace tutorial scripts: Phase 5
- iterations_per_full_update activation: still API-parity-only for nss (no separate full-update concept)
Risks / open questions
- dead list memory growth: for a typical 5000-outer-iteration run, dead accumulates 5000 NSInfo pytrees in Python memory. Each pickle write re-serialises all of them, which wastes disk I/O once the list is long. Land the naïve pickle.dump(dead) first, measure, and optimise to incremental appends later if needed. This remains an open question for Phase 4 batch users running many sequential fits.
- Checkpoint after success: Nautilus deletes its checkpoint.hdf5 after completion. Mirror that pattern by deleting nss_checkpoint.pkl on successful _fit exit, so the next fresh fit doesn't accidentally resume from a stale checkpoint. Open: should we leave the file with a .completed suffix for forensic inspection? Probably not; samples.csv and samples_summary.json capture everything users actually need.
- Resume reproducibility: verify that a single-shot 50-iteration run and a split run (run 25, save, resume, run 25 more) produce byte-identical state.particles.position at iteration 50. Add this as a parity test in the integration smoke.
- Visualization cost: analysis.visualize on the autolens HST MGE problem takes seconds per call (model plots, residuals, etc.). At iterations_per_quick_update=10 with num_delete=50, that's one visualization every 500 dead points. Each replacement runs num_mcmc_steps slice-sampling steps, so that is at least 500 × num_mcmc_steps likelihood evaluations between visualizations, which is fine. Document this so users don't set iterations_per_quick_update=1 and tank performance.
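If the naïve pickle.dump(dead) proves too slow, one incremental option is an append-only stream: each checkpoint appends only the dead points produced since the previous one, and loading reads chunks until end of file. This is a sketch of one possible approach, not a decided design. append_dead and read_dead are hypothetical names. n_expected assumes the small state checkpoint records len(dead), so that chunks written after the last state save, or a torn final chunk, are discarded on resume.

```python
import pickle
from pathlib import Path


def append_dead(path, new_dead):
    """Append only the dead points produced since the previous checkpoint."""
    with open(path, "ab") as f:
        pickle.dump(list(new_dead), f, protocol=pickle.HIGHEST_PROTOCOL)


def read_dead(path, n_expected):
    """Concatenate appended chunks, keeping the first n_expected dead points.

    n_expected is assumed to come from the (small) state checkpoint. Chunks
    after it, or a torn final chunk, are dropped. A resumed run should then
    rewrite the file from the returned list before appending again; otherwise
    new chunks would land after the torn bytes and be unreadable.
    """
    dead = []
    with open(path, "rb") as f:
        while True:
            try:
                dead.extend(pickle.load(f))
            except (EOFError, pickle.UnpicklingError):
                break  # clean end of file, or a truncated final chunk
    return dead[:n_expected]
```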
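For the resume-reproducibility check, byte-level comparison is stricter than np.array_equal. np.array_equal treats NaN as unequal to itself and 0.0 as equal to -0.0. Below is a sketch of a pytree comparison helper that checks structure, dtype, shape and raw bytes of every leaf; assert_pytrees_identical is a hypothetical name.

```python
import jax
import numpy as np


def assert_pytrees_identical(a, b):
    """Fail unless two pytrees match in structure, dtype, shape and raw bytes."""
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        raise AssertionError("pytree structures differ")
    leaves_a = jax.tree_util.tree_leaves(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    for i, (x, y) in enumerate(zip(leaves_a, leaves_b)):
        x, y = np.asarray(x), np.asarray(y)
        if x.dtype != y.dtype or x.shape != y.shape:
            raise AssertionError(f"leaf {i}: {x.dtype}{x.shape} vs {y.dtype}{y.shape}")
        # Comparing bytes catches signed zeros and treats identical NaNs as equal.
        if x.tobytes() != y.tobytes():
            raise AssertionError(f"leaf {i}: values differ")
```

Assuming the state.particles.position attribute path from the plan, calling `assert_pytrees_identical(single_shot.particles.position, resumed.particles.position)` would express the iteration-50 parity check.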
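The visualization-cost trade-off is simple arithmetic. Each outer iteration replaces num_delete live points. Assuming each replacement runs num_mcmc_steps slice-sampling steps of at least one likelihood call each, the gap between quick updates is at least iterations_per_quick_update × num_delete × num_mcmc_steps likelihood evaluations. With the plan's iterations_per_quick_update=10 and num_delete=50, that is 500 dead points per update.

Below is a sketch of a helper that computes this and warns below a caller-chosen threshold. Both function names and min_dead_points are hypothetical.

```python
import warnings


def quick_update_cadence(iterations_per_quick_update, num_delete, num_mcmc_steps):
    """Dead points, and a lower bound on likelihood calls, between quick updates."""
    dead_points = iterations_per_quick_update * num_delete
    # Assumes each replacement runs num_mcmc_steps slice steps of >= 1 call each.
    min_likelihood_calls = dead_points * num_mcmc_steps
    return dead_points, min_likelihood_calls


def check_quick_update_cadence(iterations_per_quick_update, num_delete,
                               num_mcmc_steps, min_dead_points):
    """Warn when visualization would run more often than the caller tolerates."""
    dead_points, min_calls = quick_update_cadence(
        iterations_per_quick_update, num_delete, num_mcmc_steps
    )
    if dead_points < min_dead_points:
        warnings.warn(
            f"iterations_per_quick_update={iterations_per_quick_update} visualizes "
            f"every {dead_points} dead points (>= {min_calls} likelihood calls); "
            "visualization may dominate the run time."
        )
    return dead_points, min_calls
```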
Original Prompt
Add checkpointing + on-the-fly visualization to af.NSS.
This is Phases 2 and 3 of z_features/nss_first_class_sampler.md. Both
phases share the same architectural hook — the outer loop in
af.NSS._fit — so they ship together in one PR.
Critical finding from Phase 1 follow-up audit
The z_features roadmap claimed nss.ns.run_nested_sampling is "a one-shot
JIT'd jax.lax.while_loop" requiring an upstream yallup/nss PR to add
a checkpoint hook. This is wrong. Inspecting the actual upstream
source (/home/jammy/venv/PyAuto/lib/python3.12/site-packages/nss/ns.py):
```python
@jax.jit
def one_step(carry, xs):
    state, k = carry
    k, subk = jax.random.split(k, 2)
    state, dead_point = algo.step(subk, state)
    return (state, k), dead_point

dead = []
while not state.integrator.logZ_live - state.integrator.logZ < termination:
    (state, rng_key), dead_info = one_step((state, rng_key), None)
    dead.append(dead_info)
```
The JIT boundary is one_step (one outer iteration = num_delete deaths
processed in one batch). The outer while loop is plain Python. This
means both checkpointing and on-the-fly visualization can be implemented
entirely inside af.NSS._fit — no upstream PR needed.
[... full prompt as authored, truncated here for brevity in the GitHub-rendered issue. See PyAutoPrompt/issued/nss_checkpointing_and_visualization.md for the verbatim source.]