Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions autofit/non_linear/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,14 +858,16 @@ def _fit_bypass_test_mode(
# (grid search log_evidences, subhalo Bayesian model comparison,
# scrape aggregator assertions) doesn't crash on None. SamplesPDF
# reads log_evidence from samples_info.
samples_info = {
"total_iterations": 1,
"time": 0.0,
"log_evidence": log_likelihood,
}
samples_info.update(self._test_mode_samples_info())
samples = SamplesPDF(
model=model,
sample_list=sample_list,
samples_info={
"total_iterations": 1,
"time": 0.0,
"log_evidence": log_likelihood,
},
samples_info=samples_info,
)

samples_summary = samples.summary()
Expand All @@ -888,6 +890,19 @@ def _fit_bypass_test_mode(

return result

def _test_mode_samples_info(self) -> dict:
"""
Sampler-specific keys to merge into ``samples_info`` when the
sampler is bypassed via ``PYAUTO_TEST_MODE=2`` or ``=3``.

Override in subclasses to add the diagnostic keys that the real
run would populate (e.g. NUTS ESS, MCMC autocorrelations) so that
tutorial scripts and downstream code can access those keys
without ``KeyError``. Use NaN/0 placeholders — the bypass did not
actually sample.
"""
return {}

@staticmethod
def _build_fake_samples(model, parameter_vector, log_likelihood):
"""
Expand Down
14 changes: 14 additions & 0 deletions autofit/non_linear/search/mcmc/blackjax/nuts/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,20 @@ def output_search_internal(self, search_internal):
with open(self.backend_filename, "wb") as f:
pickle.dump(search_internal, f)

def _test_mode_samples_info(self) -> dict:
return {
"num_warmup": int(self.num_warmup),
"num_samples": 0,
"num_chains": int(self.num_chains),
"ess_min": float("nan"),
"ess_per_param": [],
"mean_acceptance": float("nan"),
"n_divergent": 0,
"n_logl_evals": 0,
"total_walkers": int(self.num_chains),
"total_steps": 0,
}

def samples_info_from(self, search_internal=None):
search_internal = search_internal if search_internal is not None else self.backend

Expand Down
Loading