feat: add BlackJAXNUTS first-class non-linear search #1256

Merged
Jammy2211 merged 1 commit into main from feature/blackjax-nuts-search on May 6, 2026

Conversation

@Jammy2211
Collaborator

Summary

  • Adds af.BlackJAXNUTS — BlackJAX gradient-based No-U-Turn Sampler as a first-class non-linear search alongside Emcee/Zeus/Nautilus etc. Inherits AbstractMCMC so samples / paths / plotting / auto-correlation hooks slot in for free.
  • Two-phase fit: blackjax.window_adaptation warmup (step size + diagonal inverse mass matrix) → NUTS sampling in a JIT'd jax.lax.scan, chunked by iterations_per_full_update so perform_update flushes samples.csv and plots periodically (mirrors the Emcee chunking pattern).
  • Target log-density built from Fitness.call directly (the pure-JAX path; call_wrap/__call__ are intentionally bypassed because they convert to Python float and break NUTS gradients). Analysis(use_jax=True) is strictly required — a clear error is raised otherwise.
  • Sampling in physical parameter space; bounded priors contribute -inf outside their support. Single chain in v1 (num_chains=1); >1 raises NotImplementedError. Resume is stubbed for later — the pickle layout leaves room for it.
  • SamplesMCMC reused with uniform weights (NUTS gives unweighted draws). AutoCorrelations populated from BlackJAX per-parameter ESS via τ_int = N / ESS (canonical identity, robust Geyer-style estimator). samples_info exposes NUTS diagnostics: ess_min, ess_per_param, mean_acceptance, n_divergent, n_logl_evals (sum of leapfrog integration steps).
  • blackjax>=1.2.0 added to optional-dependencies.optional (lazy import in _fit).

API Changes

Full API change list

Added

  • autofit.BlackJAXNUTS(name=None, path_prefix=None, unique_tag=None, num_warmup=500, num_samples=1000, num_chains=1, target_accept=0.8, max_num_doublings=10, seed=42, initializer=None, auto_correlation_settings=AutoCorrelationsSettings(check_for_convergence=False), iterations_per_quick_update=None, iterations_per_full_update=None, number_of_cores=1, silence=False, session=None, **kwargs) — gradient-based MCMC sampler. Requires Analysis(use_jax=True).
  • New module path: autofit.non_linear.search.mcmc.blackjax.nuts.search (the directory layout reserves the blackjax/ namespace for future BlackJAX samplers like BlackJAXHMC).
  • New optional dep: blackjax>=1.2.0.

Removed

  • None.

Renamed

  • None.

Changed Signature

  • None — fully additive.

Changed Behaviour

  • None — fully additive.

Migration

  • No migration needed for existing code. Users wanting to try the new sampler should pip install autofit[optional] (which now pulls blackjax) and follow autofit_workspace_test/scripts/searches/BlackJAXNUTS.py (added in the companion workspace PR); a hedged usage sketch is shown below.
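A minimal usage sketch, assuming the constructor arguments listed under "Added" above and PyAutoFit's standard `search.fit(model=..., analysis=...)` call. The model and analysis objects are placeholders (the exact `Analysis(use_jax=True)` construction should follow the workspace script), not the PR's own example:

```python
import autofit as af

# Placeholders: build these as in the workspace script.
# af.ex.Gaussian is illustrative; the analysis must be constructed with use_jax=True.
model = af.Model(af.ex.Gaussian)
analysis = ...  # a JAX-compatible Analysis (use_jax=True)

search = af.BlackJAXNUTS(
    name="gaussian_nuts",
    num_warmup=500,
    num_samples=1000,
    target_accept=0.8,
    seed=42,
)

result = search.fit(model=model, analysis=analysis)
print(result.max_log_likelihood_instance)  # standard PyAutoFit result accessor
```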

Test plan

  • pytest test_autofit/non_linear/search/mcmc — 9 passed (7 new BlackJAXNUTS tests + existing emcee/zeus, no regressions).
  • pytest test_autofit/non_linear — 249 passed, no regressions.
  • End-to-end smoke on the 1D Gaussian (Analysis with use_jax=True): recovers truth (centre, normalization, sigma) ≈ (50.04, 25.06, 9.997) within 1σ, ESS 256/500, 0 divergences.
  • Disk output verified: samples.csv, samples_info.json, samples_summary.json, model.results, corner plot, and search_internal/search_internal.pickle all written under the standard autofit paths.

Closes #1255.

🤖 Generated with Claude Code

Adds `af.BlackJAXNUTS` alongside `Emcee`/`Zeus`/`Nautilus` etc., wrapping
BlackJAX's gradient-based No-U-Turn Sampler on top of PyAutoFit's
existing JAX-aware Fitness machinery. The class is named `BlackJAXNUTS`
(not `BlackJAX`) because BlackJAX is a library of samplers — NUTS is
one of several. The directory layout
`autofit/non_linear/search/mcmc/blackjax/nuts/search.py` leaves room
for `BlackJAXHMC`, `BlackJAXMALA`, etc. later.

The fit runs `blackjax.window_adaptation` (tunes step size + diagonal
inverse mass matrix) followed by NUTS sampling in a JIT'd
`jax.lax.scan`, broken into `iterations_per_full_update`-sized chunks
so `perform_update` flushes samples.csv and plots periodically (mirrors
the Emcee chunking pattern). The target log-density is built from
`Fitness.call` directly (the pure-JAX path; `call_wrap`/`__call__` are
intentionally bypassed because they convert to Python float and break
NUTS gradients).
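For concreteness, a minimal sketch of this warmup-then-chunked-scan pattern against BlackJAX's public API. It is not the PR's implementation: the chunk bookkeeping, the returned fields, and the free-standing `logdensity_fn` argument are assumptions.

```python
import jax
import jax.numpy as jnp
import blackjax


def run_nuts(logdensity_fn, initial_position, rng_key,
             num_warmup=500, num_samples=1000,
             target_accept=0.8, chunk_size=250):
    warmup_key, sample_key = jax.random.split(rng_key)

    # Phase 1: window adaptation tunes the step size and diagonal inverse mass matrix.
    warmup = blackjax.window_adaptation(
        blackjax.nuts, logdensity_fn, target_acceptance_rate=target_accept
    )
    (state, tuned_params), _ = warmup.run(
        warmup_key, initial_position, num_steps=num_warmup
    )

    # Phase 2: NUTS steps inside a JIT'd jax.lax.scan, run chunk by chunk so the
    # caller can flush outputs (samples.csv, plots) between chunks.
    kernel = blackjax.nuts(logdensity_fn, **tuned_params).step

    @jax.jit
    def scan_chunk(state, keys):
        def one_step(state, key):
            state, info = kernel(key, state)
            return state, (state.position, info.is_divergent)
        return jax.lax.scan(one_step, state, keys)

    positions, done = [], 0
    while done < num_samples:
        n = min(chunk_size, num_samples - done)
        sample_key, chunk_key = jax.random.split(sample_key)
        state, (chunk_positions, divergent) = scan_chunk(
            state, jax.random.split(chunk_key, n)
        )
        positions.append(chunk_positions)
        done += n
        # ... this is where the search would call perform_update to flush outputs ...
    return jnp.concatenate(positions, axis=0), state, tuned_params
```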

`Analysis(use_jax=True)` is strictly required, and a clear error is raised
when it is not set. Sampling happens in physical parameter space; bounded
priors contribute -inf outside their support.
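A sketch of such a bounded log-density, assuming a free-standing JAX log-likelihood in place of `Fitness.call` and hypothetical prior bounds:

```python
import jax.numpy as jnp

# Hypothetical bounds; in the search these come from the model's priors.
LOWER = jnp.array([0.0, 0.0, 0.0])
UPPER = jnp.array([100.0, 100.0, 25.0])


def log_likelihood(params):
    # Stand-in for Fitness.call(params); must stay pure JAX so NUTS can differentiate it.
    return -0.5 * jnp.sum(params ** 2)


def logdensity_fn(params):
    inside = jnp.all((params >= LOWER) & (params <= UPPER))
    return jnp.where(inside, log_likelihood(params), -jnp.inf)
```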

Persistence is a pickle of {positions, log_likelihood_history, infos,
tuned_params, last_state_position, num_warmup, num_samples_completed,
num_chains} under `search_internal/search_internal.pickle`. Resume is
not implemented in v1, but the on-disk format leaves room for it.
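As a data-layout illustration only (the keys follow the list above; the values, shapes, and path handling are dummies):

```python
import pickle
import numpy as np

search_internal = {
    "positions": np.zeros((1000, 3)),       # (num_samples, num_params)
    "log_likelihood_history": np.zeros(1000),
    "infos": [],                            # per-step NUTS info (acceptance, divergences, ...)
    "tuned_params": {"step_size": 0.5, "inverse_mass_matrix": np.ones(3)},
    "last_state_position": np.zeros(3),
    "num_warmup": 500,
    "num_samples_completed": 1000,
    "num_chains": 1,
}

with open("search_internal.pickle", "wb") as f:
    pickle.dump(search_internal, f)
```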

`SamplesMCMC` plumbing is reused via `AbstractMCMC` inheritance — the
chain converts to the standard sample format (uniform weights — NUTS
gives unweighted draws). `AutoCorrelations` is populated from BlackJAX's
per-parameter ESS via the canonical identity τ_int = N / ESS, with
`previous_times` computed on the chain truncated by `check_size` (matches
emcee's slicing pattern). `samples_info` exposes NUTS-specific
diagnostics: ess_min, ess_per_param, mean_acceptance, n_divergent,
n_logl_evals (sum of leapfrog integration steps).
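A sketch of that conversion, assuming BlackJAX's `effective_sample_size` diagnostic and a single chain of draws with shape `(N, D)`:

```python
import numpy as np
from blackjax.diagnostics import effective_sample_size


def autocorrelation_times(positions):
    # positions: (N, D) draws from one chain; add a leading chain axis because
    # the ESS diagnostic expects (chains, samples, params).
    chains = np.asarray(positions)[None, ...]
    ess = np.asarray(effective_sample_size(chains))
    return chains.shape[1] / ess  # tau_int = N / ESS, per parameter
```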

`blackjax>=1.2.0` is added to `optional-dependencies.optional`. The
import is lazy (inside `_fit`) so PyAutoFit installs without it.
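A hedged sketch of that lazy-import pattern (the method body and error wording are illustrative, not the PR's code):

```python
def _fit(self, model, analysis):
    # Import inside the method so `import autofit` succeeds without blackjax installed.
    try:
        import blackjax  # noqa: F401
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "af.BlackJAXNUTS requires the optional dependency blackjax "
            "(`pip install autofit[optional]` or `pip install 'blackjax>=1.2.0'`)."
        ) from e
    ...
```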

Closes #1255.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>