diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py
index b2b72eb9b..192a7ae4d 100644
--- a/autofit/non_linear/search/abstract_search.py
+++ b/autofit/non_linear/search/abstract_search.py
@@ -858,14 +858,16 @@ def _fit_bypass_test_mode(
         # (grid search log_evidences, subhalo Bayesian model comparison,
         # scrape aggregator assertions) doesn't crash on None. SamplesPDF
         # reads log_evidence from samples_info.
+        samples_info = {
+            "total_iterations": 1,
+            "time": 0.0,
+            "log_evidence": log_likelihood,
+        }
+        samples_info.update(self._test_mode_samples_info())
         samples = SamplesPDF(
             model=model,
             sample_list=sample_list,
-            samples_info={
-                "total_iterations": 1,
-                "time": 0.0,
-                "log_evidence": log_likelihood,
-            },
+            samples_info=samples_info,
         )
 
         samples_summary = samples.summary()
@@ -888,6 +890,19 @@
 
         return result
 
+    def _test_mode_samples_info(self) -> dict:
+        """
+        Sampler-specific keys to merge into ``samples_info`` when the
+        sampler is bypassed via ``PYAUTO_TEST_MODE=2`` or ``=3``.
+
+        Override in subclasses to add the diagnostic keys that the real
+        run would populate (e.g. NUTS ESS, MCMC autocorrelations) so that
+        tutorial scripts and downstream code can access those keys
+        without ``KeyError``. Use NaN/0 placeholders — the bypass did not
+        actually sample.
+        """
+        return {}
+
     @staticmethod
     def _build_fake_samples(model, parameter_vector, log_likelihood):
         """
diff --git a/autofit/non_linear/search/mcmc/blackjax/nuts/search.py b/autofit/non_linear/search/mcmc/blackjax/nuts/search.py
index 6bf9e9e57..986d99b26 100644
--- a/autofit/non_linear/search/mcmc/blackjax/nuts/search.py
+++ b/autofit/non_linear/search/mcmc/blackjax/nuts/search.py
@@ -378,6 +378,20 @@ def output_search_internal(self, search_internal):
         with open(self.backend_filename, "wb") as f:
             pickle.dump(search_internal, f)
 
+    def _test_mode_samples_info(self) -> dict:
+        return {
+            "num_warmup": int(self.num_warmup),
+            "num_samples": 0,
+            "num_chains": int(self.num_chains),
+            "ess_min": float("nan"),
+            "ess_per_param": [],
+            "mean_acceptance": float("nan"),
+            "n_divergent": 0,
+            "n_logl_evals": 0,
+            "total_walkers": int(self.num_chains),
+            "total_steps": 0,
+        }
+
     def samples_info_from(self, search_internal=None):
         search_internal = search_internal if search_internal is not None else self.backend
 