feat: add BlackJAXNUTS first-class non-linear search #1256
Merged
Adds `af.BlackJAXNUTS` alongside `Emcee`/`Zeus`/`Nautilus` etc., wrapping
BlackJAX's gradient-based No-U-Turn Sampler on top of PyAutoFit's
existing JAX-aware Fitness machinery. The class is named `BlackJAXNUTS`
(not `BlackJAX`) because BlackJAX is a library of samplers — NUTS is
one of several. The directory layout
`autofit/non_linear/search/mcmc/blackjax/nuts/search.py` leaves room
for `BlackJAXHMC`, `BlackJAXMALA`, etc. later.
The fit runs `blackjax.window_adaptation` (tunes step size + diagonal
inverse mass matrix) followed by NUTS sampling in a JIT'd
`jax.lax.scan`, broken into `iterations_per_full_update`-sized chunks
so `perform_update` flushes samples.csv and plots periodically (mirrors
the Emcee chunking pattern). The target log-density is built from
`Fitness.call` directly (the pure-JAX path; `call_wrap`/`__call__` are
intentionally bypassed because they convert to Python float and break
NUTS gradients).
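The chunked sampling loop described above can be sketched in plain Python. This is an illustrative stand-in, not the PR's code: `sample_chunk` stands in for the JIT'd `jax.lax.scan` step, and `perform_update` for the flush of samples.csv and plots; the names `iterations_per_full_update` and `perform_update` follow the PR's description.

```python
def run_chunked(num_samples, iterations_per_full_update, sample_chunk, perform_update):
    """Draw num_samples in chunks so intermediate output can be flushed.

    sample_chunk(n) -> list of n draws (stand-in for the JIT'd lax.scan).
    perform_update(samples) -> flushes samples.csv / plots (stand-in).
    """
    samples = []
    remaining = num_samples
    while remaining > 0:
        n = min(iterations_per_full_update, remaining)
        samples.extend(sample_chunk(n))
        remaining -= n
        perform_update(samples)  # periodic flush, mirroring the Emcee chunking pattern
    return samples
```

The key design point is that the expensive sampling stays inside the compiled inner step, while the Python-level loop only runs once per chunk.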
`Analysis(use_jax=True)` is required; the search raises a clear error when it is not set. Sampling happens in physical parameter space, and bounded priors contribute -inf to the log-density outside their support.
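A minimal stand-in shows what "contribute -inf outside support" means for a bounded prior; the real prior classes live in PyAutoFit, so this sketch is purely illustrative.

```python
import math

def uniform_prior_log_pdf(x, lower, upper):
    """Log-density of a Uniform(lower, upper) prior in physical parameter space.

    Outside the support the log-density is -inf, so the total target
    log-density is -inf there and NUTS rejects proposals into that region.
    """
    if x < lower or x > upper:
        return -math.inf
    return -math.log(upper - lower)
```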
Persistence is a pickle of {positions, log_likelihood_history, infos,
tuned_params, last_state_position, num_warmup, num_samples_completed,
num_chains} under `search_internal/search_internal.pickle`. Resume is
not implemented in v1, but the on-disk format leaves room for it.
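The persisted state can be exercised with a plain pickle round-trip. The keys match those listed above; the values here are dummies, and the shapes of `infos` and `tuned_params` are assumptions for illustration.

```python
import pickle

# Dummy stand-in for the state written to search_internal/search_internal.pickle.
state = {
    "positions": [[0.1, 0.2], [0.3, 0.4]],      # chain positions (dummy values)
    "log_likelihood_history": [-3.2, -2.9],
    "infos": [{"is_divergent": False}],          # per-step NUTS info (assumed shape)
    "tuned_params": {"step_size": 0.5},          # warmup output (assumed shape)
    "last_state_position": [0.3, 0.4],
    "num_warmup": 500,
    "num_samples_completed": 2,
    "num_chains": 1,
}

blob = pickle.dumps(state)
restored = pickle.loads(blob)
```

Because the dict carries `last_state_position` and `num_samples_completed`, a future resume implementation can reconstruct where sampling left off without changing the on-disk format.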
`SamplesMCMC` plumbing is reused via `AbstractMCMC` inheritance — the
chain converts to the standard sample format (uniform weights — NUTS
gives unweighted draws). `AutoCorrelations` is populated from BlackJAX's
per-parameter ESS via the canonical identity τ_int = N / ESS, with
`previous_times` computed on the chain truncated by `check_size` (matches
emcee's slicing pattern). `samples_info` exposes NUTS-specific
diagnostics: ess_min, ess_per_param, mean_acceptance, n_divergent,
n_logl_evals (sum of leapfrog integration steps).
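The ESS-to-autocorrelation-time conversion and the `check_size` truncation can be sketched in pure Python. The function names and the injected `ess_fn` are hypothetical; the real code uses BlackJAX's ESS estimator on the actual chain.

```python
def autocorrelation_times(ess_per_param, n_samples):
    """Canonical identity: tau_int = N / ESS, per parameter."""
    return [n_samples / ess for ess in ess_per_param]

def previous_times(ess_fn, chain, check_size):
    """Recompute tau on the chain with the last check_size steps removed,
    mirroring emcee's slicing pattern for convergence checks."""
    truncated = chain[: len(chain) - check_size]
    return autocorrelation_times(ess_fn(truncated), len(truncated))
```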
`blackjax>=1.2.0` is added to `optional-dependencies.optional`. The
import is lazy (inside `_fit`) so PyAutoFit installs without it.
Closes #1255.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
`af.BlackJAXNUTS`: BlackJAX gradient-based No-U-Turn Sampler as a first-class non-linear search alongside `Emcee`/`Zeus`/`Nautilus` etc. Inherits `AbstractMCMC` so samples / paths / plotting / auto-correlation hooks slot in for free.
- `blackjax.window_adaptation` warmup (step size + diagonal inverse mass matrix), then NUTS sampling in a JIT'd `jax.lax.scan`, chunked by `iterations_per_full_update` so `perform_update` flushes samples.csv and plots periodically (mirrors the Emcee chunking pattern).
- Target log-density built from `Fitness.call` directly (pure-JAX path; `call_wrap`/`__call__` are intentionally bypassed because they convert to Python float and break NUTS gradients).
- `Analysis(use_jax=True)` is strictly required; a clear error is raised otherwise.
- Single chain for now (`num_chains=1`); >1 raises `NotImplementedError`. Resume is stubbed for later; the pickle layout leaves room.
- `SamplesMCMC` reused with uniform weights (NUTS gives unweighted draws). `AutoCorrelations` populated from BlackJAX per-parameter ESS via τ_int = N / ESS (canonical identity, robust Geyer-style estimator).
- `samples_info` exposes NUTS diagnostics: `ess_min`, `ess_per_param`, `mean_acceptance`, `n_divergent`, `n_logl_evals` (sum of leapfrog integration steps).
- `blackjax>=1.2.0` added to `optional-dependencies.optional` (lazy import in `_fit`).

API Changes
Full API change list
Added
- `autofit.BlackJAXNUTS(name=None, path_prefix=None, unique_tag=None, num_warmup=500, num_samples=1000, num_chains=1, target_accept=0.8, max_num_doublings=10, seed=42, initializer=None, auto_correlation_settings=AutoCorrelationsSettings(check_for_convergence=False), iterations_per_quick_update=None, iterations_per_full_update=None, number_of_cores=1, silence=False, session=None, **kwargs)`: gradient-based MCMC sampler. Requires `Analysis(use_jax=True)`.
- `autofit.non_linear.search.mcmc.blackjax.nuts.search` (the directory layout reserves the `blackjax/` namespace for future BlackJAX samplers like `BlackJAXHMC`).
- `blackjax>=1.2.0`.

Removed
Renamed
Changed Signature
Changed Behaviour
Migration
`pip install autofit[optional]` (which now pulls `blackjax`) and follow `autofit_workspace_test/scripts/searches/BlackJAXNUTS.py` (added in the companion workspace PR).

Test plan
- `pytest test_autofit/non_linear/search/mcmc`: 9 passed (7 new BlackJAXNUTS tests + existing emcee/zeus, no regressions).
- `pytest test_autofit/non_linear`: 249 passed, no regressions.
- End-to-end fit (`use_jax=True`): recovers truth (centre, normalization, sigma) ≈ (50.04, 25.06, 9.997) within 1σ, ESS 256/500, 0 divergences.
- `samples.csv`, `samples_info.json`, `samples_summary.json`, `model.results`, corner plot, and `search_internal/search_internal.pickle` all written under the standard autofit paths.

Closes #1255.
🤖 Generated with Claude Code