feat: add BlackJAXNUTS first-class non-linear search #1255

@Jammy2211

Overview

Add af.BlackJAXNUTS as a first-class non-linear search alongside Emcee, Zeus, Nautilus etc. Wraps BlackJAX's NUTS (gradient-based MCMC) on top of PyAutoFit's existing JAX-aware Fitness machinery, slotted into the standard AbstractMCMC plumbing so samples / paths / plotting / auto-correlation hooks all work without bespoke handling.

The class is named BlackJAXNUTS (not BlackJAX) because BlackJAX is a library of samplers — NUTS is one of several. The directory layout (autofit/non_linear/search/mcmc/blackjax/nuts/search.py) leaves room for BlackJAXHMC, BlackJAXMALA, etc. later.

Plan

  • New search class BlackJAXNUTS(AbstractMCMC) runs blackjax.window_adaptation for warmup (tunes step size + diagonal inverse mass matrix), then a JIT'd jax.lax.scan NUTS chain, broken into iterations_per_full_update-sized chunks so perform_update() flushes samples.csv and plots periodically.
  • Target log-density built from PyAutoFit's existing Fitness wrapper (fom_is_log_likelihood=False) — the analysis must be constructed with use_jax=True, mirroring the existing Nautilus_jax.py pattern (enable_pytrees() + register_model(model)). Strict requirement; raises a clear error if use_jax=False.
  • Sampling in physical parameter space; bounded priors (Uniform etc.) contribute -inf outside support — same approach as the workspace searches_minimal/nss_grad.py.
  • Single chain in v1 (num_chains=1). No resume in v1, but the on-disk pickle layout leaves room for it.
  • Output via the standard SamplesMCMC (uniform weights, NUTS gives unweighted posterior). ESS, divergences, leapfrog eval count surfaced in samples_info.json.
  • AutoCorrelations populated with times = num_samples / ess_per_param (the standard τ_int = N / ESS identity, using BlackJAX's Geyer-style ESS estimator); previous_times computed the same way on samples[:-check_size].
  • pyproject.toml extras: append blackjax>=1.2.0 to optional-dependencies.optional.
  • Tests + workspace examples + integration script in autofit_workspace_test.
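
The bounded-prior rule in the plan (a value outside a prior's support contributes -inf) can be illustrated with a minimal sketch. The helper names uniform_log_prior, log_density and toy_log_likelihood are hypothetical and are not PyAutoFit's Fitness code; plain Python is used here for brevity, whereas a JAX-traced version would need a branch-free form such as jnp.where.

```python
import math

def uniform_log_prior(value, lower, upper):
    """Log-density of a Uniform(lower, upper) prior; -inf outside the support."""
    if lower <= value <= upper:
        return -math.log(upper - lower)
    return -math.inf

def log_density(params, bounds, log_likelihood):
    """Log-posterior in physical parameter space: log-priors plus log-likelihood."""
    log_prior = sum(
        uniform_log_prior(p, lo, hi) for p, (lo, hi) in zip(params, bounds)
    )
    if log_prior == -math.inf:
        # Any parameter outside its support makes the whole target -inf,
        # so the likelihood is not evaluated.
        return -math.inf
    return log_prior + log_likelihood(params)

def toy_log_likelihood(params):
    # Hypothetical stand-in for the analysis log-likelihood (unit Gaussian about 0).
    return -0.5 * sum(p * p for p in params)

bounds = [(-1.0, 1.0), (0.0, 10.0)]
inside = log_density([0.5, 2.0], bounds, toy_log_likelihood)
outside = log_density([1.5, 2.0], bounds, toy_log_likelihood)
```

Here inside is finite, while outside is -inf because the first parameter lies beyond its upper bound.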

Detailed implementation plan

Affected Repositories

  • rhayes777/PyAutoFit (primary, library)
  • Jammy2211/autofit_workspace (workspace example)
  • Jammy2211/autofit_workspace_test (integration script)

Work Classification

Library + workspace — library work ships first, then workspace scripts follow.

Branch Survey (run via /plan_branches)

Repository                Current branch   Dirty?
./PyAutoFit               main             clean
./autofit_workspace       main             dirty (~250 dataset files, unrelated regenerations)
./autofit_workspace_test  main             dirty (9 dataset files, unrelated)

No worktree conflicts on the three repos. Workspace dirty files stay on canonical/main; the worktree feature branch is born clean from origin/main.

Suggested branch: feature/blackjax-nuts-search
Worktree root: ~/Code/PyAutoLabs-wt/blackjax-nuts-search/ (created later by /start_library)

Files added / modified

PyAutoFit

  • autofit/non_linear/search/mcmc/blackjax/__init__.py (new)
  • autofit/non_linear/search/mcmc/blackjax/nuts/__init__.py (new)
  • autofit/non_linear/search/mcmc/blackjax/nuts/search.py (new) — BlackJAXNUTS(AbstractMCMC)
  • autofit/__init__.py (modified): add from .non_linear.search.mcmc.blackjax.nuts.search import BlackJAXNUTS
  • pyproject.toml — append "blackjax>=1.2.0" to optional-dependencies.optional
  • test_autofit/non_linear/search/mcmc/test_blackjax_nuts.py (new) — parameter validation, identifier fields, test-mode iteration reduction

autofit_workspace

  • scripts/searches/mcmc.py — append BlackJAXNUTS section with the use_jax=True requirement called out

autofit_workspace_test

  • scripts/searches/BlackJAXNUTS.py (new) — full integration fit on the 1D Gaussian, mirrors Nautilus_jax.py

BlackJAXNUTS constructor

class BlackJAXNUTS(AbstractMCMC):
    __identifier_fields__ = ("num_warmup", "num_samples", "num_chains")

    def __init__(
        self,
        name=None, path_prefix=None, unique_tag=None,
        num_warmup: int = 500,
        num_samples: int = 1000,
        num_chains: int = 1,
        target_accept: float = 0.8,
        max_num_doublings: int = 10,
        seed: int = 42,
        initializer=None,
        auto_correlation_settings=AutoCorrelationsSettings(check_for_convergence=False),
        iterations_per_quick_update=None,
        iterations_per_full_update=None,
        number_of_cores: int = 1,
        silence: bool = False,
        session=None,
        **kwargs,
    ):

apply_test_mode() reduces to num_warmup=20, num_samples=20.

_fit(model, analysis) flow

  1. Lazy import blackjax, jax, jax.numpy as jnp.
  2. Validate analysis.use_jax is True (or analysis._xp is jnp); raise a clear error otherwise.
  3. Build Fitness(model, analysis, fom_is_log_likelihood=False, ...) → log-density target. Call enable_pytrees() + register_model(model) so instance_from_vector JIT-traces.
  4. Initial position via self.initializer.samples_from_model(total_points=num_chains, ...).
  5. blackjax.window_adaptation(blackjax.nuts, log_density, target_acceptance_rate=target_accept, max_num_doublings=max_num_doublings) for num_warmup steps.
  6. Build tuned NUTS kernel; sample in iterations_per_full_update-sized chunks via jax.lax.scan. After each chunk, output_search_internal() + self.perform_update(...).
  7. Return (search_internal_dict, fitness).
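
The chunking in step 6 can be sketched as follows. This is only an illustration of the control flow under stated assumptions: chunk_sizes and run_in_chunks are hypothetical helpers, sample_chunk stands in for the JIT-compiled jax.lax.scan call, and on_update stands in for output_search_internal() plus perform_update(). Treating a missing iterations_per_full_update as "one single chunk" is also an assumption.

```python
def chunk_sizes(num_samples, iterations_per_full_update):
    """Split num_samples into chunks of at most iterations_per_full_update steps."""
    if iterations_per_full_update is None or iterations_per_full_update <= 0:
        # Assumption: with no update interval, run everything as one chunk.
        return [num_samples]
    full, remainder = divmod(num_samples, iterations_per_full_update)
    return [iterations_per_full_update] * full + ([remainder] if remainder else [])

def run_in_chunks(num_samples, iterations_per_full_update, sample_chunk, on_update):
    """Run sample_chunk once per chunk and call on_update after each chunk."""
    collected = []
    for size in chunk_sizes(num_samples, iterations_per_full_update):
        collected.extend(sample_chunk(size))  # stand-in for the scanned NUTS steps
        on_update(len(collected))             # stand-in for the periodic flush
    return collected

updates = []
draws = run_in_chunks(
    num_samples=1000,
    iterations_per_full_update=300,
    sample_chunk=lambda n: [0.0] * n,  # dummy sampler returning n placeholder draws
    on_update=updates.append,
)
```

With these inputs the run is split into chunks of 300, 300, 300 and 100 steps, so the update hook fires four times, the last one after a shorter final chunk.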

Persistence

The search state is written to search_internal/search_internal.pickle, containing {positions, log_likelihood_history, infos, tuned_params, last_state}, and is loaded via the backend property.
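
A minimal round-trip sketch of that pickle layout is shown below. The functions save_search_internal and load_search_internal are hypothetical helpers, not PyAutoFit's paths API, and the dictionary values are placeholders chosen only to show the five keys.

```python
import pickle
import tempfile
from pathlib import Path

def save_search_internal(directory, search_internal):
    """Write the state dict to <directory>/search_internal/search_internal.pickle."""
    path = Path(directory) / "search_internal" / "search_internal.pickle"
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(search_internal, f)
    return path

def load_search_internal(directory):
    """Read the state dict back from the same location."""
    path = Path(directory) / "search_internal" / "search_internal.pickle"
    with open(path, "rb") as f:
        return pickle.load(f)

# Placeholder values; only the key names come from the plan.
state = {
    "positions": [[0.1, 1.0], [0.2, 1.1]],
    "log_likelihood_history": [-3.2, -3.0],
    "infos": {"is_divergent": [False, False]},
    "tuned_params": {"step_size": 0.05},
    "last_state": [0.2, 1.1],
}

with tempfile.TemporaryDirectory() as tmp:
    saved_path = save_search_internal(tmp, state)
    restored = load_search_internal(tmp)
```

Keeping last_state and tuned_params in the file is what would let a later version resume a chain without repeating warmup.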

Samples conversion

samples_via_internal_from(model, search_internal):

  • Extract positions (no thinning by default).
  • Recompute log-prior via model.log_prior_list_from(parameter_lists).
  • Read log-likelihood from saved log_likelihood_history.
  • Build uniform-weight samples via Sample.from_lists, then a SamplesMCMC with auto-correlations populated.

samples_info_from

Returns the standard MCMC keys (check_size, required_length, change_threshold, total_walkers, total_steps, time) PLUS NUTS diagnostics (num_warmup, num_samples, num_chains, ess_min, ess_per_param, mean_acceptance, n_divergent, n_logl_evals).

Auto-correlation accuracy

  • times = num_samples / ess_per_param via blackjax.diagnostics.effective_sample_size(samples[None, ...]) — Geyer-style monotone variance estimator, single chain shape (1, num_samples, n_dim).
  • previous_times computed the same way on samples[:-check_size] (matches emcee's slicing pattern).
  • Clamp times to num_samples if any ess < 1 to avoid inf in convergence check.
  • Convergence check defaults off (check_for_convergence=False); user can opt in.
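
The τ = N / ESS identity and the clamp can be sketched as below. The function autocorr_times is hypothetical; it takes ESS values as plain floats rather than calling blackjax.diagnostics.effective_sample_size, and it applies the clamp per parameter, which is one possible reading of the rule above.

```python
def autocorr_times(num_samples, ess_per_param):
    """Integrated auto-correlation time per parameter, tau = N / ESS.

    Assumption: any parameter whose ESS is below 1 (or NaN) is clamped to
    tau = num_samples, so the convergence check never sees inf or NaN.
    """
    times = []
    for ess in ess_per_param:
        if not ess >= 1.0:  # also true for NaN, which fails every comparison
            times.append(float(num_samples))
        else:
            times.append(num_samples / ess)
    return times

num_samples = 1000
check_size = 100

# Illustrative ESS values for the full chain and for samples[:-check_size].
times = autocorr_times(num_samples, [250.0, 0.5, 800.0])
previous_times = autocorr_times(num_samples - check_size, [225.0, 0.4, 600.0])
```

The second parameter in each call has ESS below 1, so its time is clamped to the chain length instead of exceeding it.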

Tests

  • test__explicit_params — constructor round-trip.
  • test__test_mode_reduces_iterations — env-var path → reduced run shape.
  • test__identifier_fields checks that (num_warmup, num_samples, num_chains) produce distinct paths.

No "fits a real Gaussian" unit test — JAX install + pytree registration make that an integration concern.

Workspace integration test

autofit_workspace_test/scripts/searches/BlackJAXNUTS.py mirrors Nautilus_jax.py:

  • Load 1D Gaussian dataset.
  • enable_pytrees() + register_model(model).
  • Analysis(use_jax=True).
  • af.BlackJAXNUTS(num_warmup=200, num_samples=500).
  • result = search.fit(model=model, analysis=analysis).
  • Print recovered (centre, normalization, sigma) and samples.converged.

Open questions resolved

  1. Sampling space: physical-parameter-space.
  2. Resume: skip in v1, layout supports later.
  3. Multi-chain: default 1.
  4. AutoCorrelations: ESS-synthesised with τ = N/ESS, with robustness clamp.
  5. JAX requirement: strict (clear error if use_jax=False).
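
The strict JAX requirement in item 5 could look like the sketch below. JaxRequiredError, require_jax and DummyAnalysis are all hypothetical names for illustration; the real check may also accept analysis._xp being jax.numpy, which is omitted here.

```python
class JaxRequiredError(ValueError):
    """Hypothetical exception type used only in this sketch."""

def require_jax(analysis):
    """Raise a clear error unless the analysis was built with use_jax=True."""
    use_jax = getattr(analysis, "use_jax", None)
    if use_jax is not True:
        raise JaxRequiredError(
            "BlackJAXNUTS requires an analysis constructed with use_jax=True; "
            f"got use_jax={use_jax!r}."
        )

class DummyAnalysis:
    """Stand-in for an analysis object carrying only the use_jax flag."""
    def __init__(self, use_jax):
        self.use_jax = use_jax

require_jax(DummyAnalysis(use_jax=True))  # passes silently

try:
    require_jax(DummyAnalysis(use_jax=False))
    message = None
except JaxRequiredError as error:
    message = str(error)
```

Failing before any kernel is built keeps the error readable, rather than surfacing later as a tracing failure inside the sampler.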

Original Prompt

Can you implement BlackJAX Nuts as a proper autofit sampler which is in the source code alongside emcee and the others. So treat it as a first class object which can do all the stuff like output results to hard disk, make their search scripts in the autofit_workspace and _test workspace. Think hard and make sure you do all the aspects we need to make it a search.
