Overview
Add af.BlackJAXNUTS as a first-class non-linear search alongside Emcee, Zeus, Nautilus etc. Wraps BlackJAX's NUTS (gradient-based MCMC) on top of PyAutoFit's existing JAX-aware Fitness machinery, slotted into the standard AbstractMCMC plumbing so samples / paths / plotting / auto-correlation hooks all work without bespoke handling.
The class is named BlackJAXNUTS (not BlackJAX) because BlackJAX is a library of samplers — NUTS is one of several. The directory layout (autofit/non_linear/search/mcmc/blackjax/nuts/search.py) leaves room for BlackJAXHMC, BlackJAXMALA, etc. later.
Plan
- New search class: BlackJAXNUTS(AbstractMCMC) runs blackjax.window_adaptation for warmup (tunes the step size and diagonal inverse mass matrix), then a JIT'd jax.lax.scan NUTS chain, broken into iterations_per_full_update-sized chunks so perform_update() flushes samples.csv and plots periodically.
- Target log-density built from PyAutoFit's existing Fitness wrapper (fom_is_log_likelihood=False) — the analysis must be constructed with use_jax=True, mirroring the existing Nautilus_jax.py pattern (enable_pytrees() + register_model(model)). This is a strict requirement; a clear error is raised if use_jax=False.
- Sampling in physical parameter space; bounded priors (Uniform etc.) contribute -inf outside their support — the same approach as the workspace searches_minimal/nss_grad.py.
- Single chain in v1 (num_chains=1). No resume in v1, but the on-disk pickle layout leaves room for it.
- Output via the standard SamplesMCMC (uniform weights, since NUTS yields unweighted posterior draws). ESS, divergences, and leapfrog evaluation count surfaced in samples_info.json.
- AutoCorrelations populated by times = num_samples / ess_per_param (the standard τ_int identity, using BlackJAX's robust Geyer-style ESS); previous_times from the same computation on samples[:-check_size].
- pyproject.toml extras: append blackjax>=1.2.0 to optional-dependencies.optional.
- Tests + workspace examples + an integration script in autofit_workspace_test.
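The bounded-prior convention above can be illustrated with a toy log-density. This is a minimal sketch, not PyAutoFit code: the function names and signature are hypothetical, and a caller-supplied log-likelihood stands in for the Fitness wrapper.

```python
import math

def log_density(params, lower, upper, log_likelihood_fn):
    # Toy stand-in for the target: a flat (Uniform) prior contributes -inf
    # outside its support, so out-of-bounds proposals are always rejected.
    if any(p < lo or p > hi for p, lo, hi in zip(params, lower, upper)):
        return -math.inf
    # Inside the support a flat prior only adds a constant, so the target
    # reduces to the log-likelihood (up to normalisation).
    return log_likelihood_fn(params)
```

In the actual JAX path the Python branch would not trace under JIT; one would instead compute an in-bounds mask and use jnp.where(in_bounds, log_likelihood, -jnp.inf).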
Detailed implementation plan
Affected Repositories
rhayes777/PyAutoFit (primary, library)
Jammy2211/autofit_workspace (workspace example)
Jammy2211/autofit_workspace_test (integration script)
Work Classification
Library + workspace — library work ships first, then workspace scripts follow.
Branch Survey (run via /plan_branches)
| Repository | Current branch | Dirty? |
|---|---|---|
| ./PyAutoFit | main | clean |
| ./autofit_workspace | main | dirty (~250 dataset files, unrelated regenerations) |
| ./autofit_workspace_test | main | dirty (9 dataset files, unrelated) |
No worktree conflicts on the three repos. Workspace dirty files stay on canonical/main; the worktree feature branch is born clean from origin/main.
Suggested branch: feature/blackjax-nuts-search
Worktree root: ~/Code/PyAutoLabs-wt/blackjax-nuts-search/ (created later by /start_library)
Files added / modified
PyAutoFit
autofit/non_linear/search/mcmc/blackjax/__init__.py (new)
autofit/non_linear/search/mcmc/blackjax/nuts/__init__.py (new)
autofit/non_linear/search/mcmc/blackjax/nuts/search.py (new) — BlackJAXNUTS(AbstractMCMC)
autofit/__init__.py — from .non_linear.search.mcmc.blackjax.nuts.search import BlackJAXNUTS
pyproject.toml — append "blackjax>=1.2.0" to optional-dependencies.optional
test_autofit/non_linear/search/mcmc/test_blackjax_nuts.py (new) — parameter validation, identifier fields, test-mode iteration reduction
autofit_workspace
scripts/searches/mcmc.py — append BlackJAXNUTS section with the use_jax=True requirement called out
autofit_workspace_test
scripts/searches/BlackJAXNUTS.py (new) — full integration fit on the 1D Gaussian, mirrors Nautilus_jax.py
BlackJAXNUTS constructor
class BlackJAXNUTS(AbstractMCMC):
    __identifier_fields__ = ("num_warmup", "num_samples", "num_chains")

    def __init__(
        self,
        name=None, path_prefix=None, unique_tag=None,
        num_warmup: int = 500,
        num_samples: int = 1000,
        num_chains: int = 1,
        target_accept: float = 0.8,
        max_num_doublings: int = 10,
        seed: int = 42,
        initializer=None,
        auto_correlation_settings=AutoCorrelationsSettings(check_for_convergence=False),
        iterations_per_quick_update=None,
        iterations_per_full_update=None,
        number_of_cores: int = 1,
        silence: bool = False,
        session=None,
        **kwargs,
    ):
apply_test_mode() reduces to num_warmup=20, num_samples=20.
_fit(model, analysis) flow
- Lazy import blackjax, jax, jax.numpy as jnp.
- Validate analysis.use_jax is True (or analysis._xp is jnp); raise a clear error otherwise.
- Build Fitness(model, analysis, fom_is_log_likelihood=False, ...) → log-density target. Call enable_pytrees() + register_model(model) so instance_from_vector JIT-traces.
- Initial position via self.initializer.samples_from_model(total_points=num_chains, ...).
- Warmup via blackjax.window_adaptation(blackjax.nuts, log_density, target_acceptance_rate=target_accept, max_num_doublings=max_num_doublings) for num_warmup steps.
- Build the tuned NUTS kernel; sample in iterations_per_full_update-sized chunks via jax.lax.scan. After each chunk, output_search_internal() + self.perform_update(...).
- Return (search_internal_dict, fitness).
Persistence
search_internal/search_internal.pickle containing {positions, log_likelihood_history, infos, tuned_params, last_state}. Loaded via the backend property.
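A sketch of the pickle round-trip, with illustrative placeholder values (the field names come from the plan; the contents shown are hypothetical):

```python
import os
import pickle
import tempfile

# Hypothetical search_internal payload; last_state keeps the final NUTS
# state so a future version could resume the chain.
search_internal = {
    "positions": [[0.1, 0.2], [0.3, 0.4]],   # per-draw parameter vectors
    "log_likelihood_history": [-1.2, -0.9],
    "infos": {"n_divergent": 0},
    "tuned_params": {"step_size": 0.5},
    "last_state": None,
}

path = os.path.join(tempfile.mkdtemp(), "search_internal.pickle")
with open(path, "wb") as f:
    pickle.dump(search_internal, f)

# The backend property would load it back the same way.
with open(path, "rb") as f:
    loaded = pickle.load(f)
```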
Samples conversion
samples_via_internal_from(model, search_internal):
- Extract positions (no thinning by default).
- Recompute log-prior via model.log_prior_list_from(parameter_lists).
- Read log-likelihood from the saved log_likelihood_history.
- Build uniform-weight Sample.from_lists → SamplesMCMC with auto-correlations populated.
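The conversion logic can be sketched in plain Python. The dict structure here is hypothetical (the real code builds Sample objects via Sample.from_lists); the point is that every NUTS draw gets weight 1, since the chain already targets the posterior:

```python
def uniform_weight_samples(positions, log_likelihoods, log_prior_fn):
    # positions: per-draw parameter vectors; log_likelihoods: saved history.
    samples = []
    for params, log_l in zip(positions, log_likelihoods):
        log_prior = log_prior_fn(params)  # recomputed from the model's priors
        samples.append(
            {
                "parameters": list(params),
                "log_likelihood": log_l,
                "log_prior": log_prior,
                "log_posterior": log_l + log_prior,
                "weight": 1.0,  # unweighted posterior draws
            }
        )
    return samples
```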
samples_info_from
Returns the standard MCMC keys (check_size, required_length, change_threshold, total_walkers, total_steps, time) PLUS NUTS diagnostics (num_warmup, num_samples, num_chains, ess_min, ess_per_param, mean_acceptance, n_divergent, n_logl_evals).
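An illustrative sketch of the resulting samples_info.json payload — the keys match the list above, but every value shown here is a made-up example:

```python
samples_info = {
    # standard MCMC keys (values illustrative)
    "check_size": 100,
    "required_length": 50,
    "change_threshold": 0.01,
    "total_walkers": 1,
    "total_steps": 1000,
    "time": "00:01:23",
    # NUTS-specific diagnostics (values illustrative)
    "num_warmup": 500,
    "num_samples": 1000,
    "num_chains": 1,
    "ess_min": 412.0,
    "ess_per_param": [412.0, 530.1, 498.7],
    "mean_acceptance": 0.81,
    "n_divergent": 0,
    "n_logl_evals": 31240,
}
```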
Auto-correlation accuracy
- times = num_samples / ess_per_param via blackjax.diagnostics.effective_sample_size(samples[None, ...]) — Geyer-style monotone variance estimator, single-chain shape (1, num_samples, n_dim).
- previous_times computed the same way on samples[:-check_size] (matches emcee's slicing pattern).
- Clamp times to num_samples if any ess < 1 to avoid inf in the convergence check.
- Convergence check defaults off (check_for_convergence=False); the user can opt in.
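The τ = N/ESS identity with the clamp can be sketched as a small NumPy helper (a hypothetical function, not existing PyAutoFit code; previous_times would call it with the ESS of samples[:-check_size]):

```python
import numpy as np

def autocorrelation_times(ess_per_param, num_samples):
    # tau_int = N / ESS per parameter, clamped to N wherever ESS < 1 so the
    # convergence check never sees inf.
    ess = np.asarray(ess_per_param, dtype=float)
    return np.where(ess < 1.0, float(num_samples), num_samples / ess)
```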
Tests
test__explicit_params — constructor round-trip.
test__test_mode_reduces_iterations — env-var path → reduced run shape.
test__identifier_fields — (num_warmup, num_samples, num_chains) distinct paths.
No "fits a real Gaussian" unit test — JAX install + pytree registration make that an integration concern.
Workspace integration test
autofit_workspace_test/scripts/searches/BlackJAXNUTS.py mirrors Nautilus_jax.py:
- Load the 1D Gaussian dataset.
- enable_pytrees() + register_model(model).
- Analysis(use_jax=True).
- af.BlackJAXNUTS(num_warmup=200, num_samples=500).
- result = search.fit(model=model, analysis=analysis).
- Print recovered (centre, normalization, sigma) and samples.converged.
Open questions resolved
- Sampling space: physical-parameter-space.
- Resume: skip in v1, layout supports later.
- Multi-chain: default 1.
- AutoCorrelations: ESS-synthesised with τ = N/ESS, with robustness clamp.
- JAX requirement: strict (clear error if use_jax=False).
Original Prompt
Can you implement BlackJAX Nuts as a proper autofit sampler which is in the source code alongside emcee and the others. So treat it as a first class object which can do all the stuff like output results to hard disk, make their search scripts in the autofit_workspace and _test workspace. Think hard and make sure you do all the aspects we need to make it a search.