From 4498e3855ca7b30c514da45e806a48ba40dadab6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 8 May 2026 15:56:03 +0100 Subject: [PATCH] fix: populate NUTS samples_info keys under test-mode bypass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `PYAUTO_TEST_MODE=2` skips the sampler in `_fit_bypass_test_mode`, producing a samples_info stub with only `total_iterations`, `time`, and `log_evidence`. Tutorial and downstream code that reads NUTS diagnostics (`ess_min`, `num_samples`, `mean_acceptance`, `n_divergent`, `n_logl_evals`) crashed with `KeyError`. Add a `_test_mode_samples_info()` hook on `AbstractSearch`, merged into the bypass `samples_info`. Override in `BlackJAXNUTS` to return the diagnostic keys with NaN/0 placeholders — bypass mode never ran the sampler, so honest empties propagate to the tutorial print. Fixes the `autofit_workspace/scripts/searches/mcmc.py` smoke test. Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/non_linear/search/abstract_search.py | 25 +++++++++++++++---- .../search/mcmc/blackjax/nuts/search.py | 14 +++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index b2b72eb9b..192a7ae4d 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -858,14 +858,16 @@ def _fit_bypass_test_mode( # (grid search log_evidences, subhalo Bayesian model comparison, # scrape aggregator assertions) doesn't crash on None. SamplesPDF # reads log_evidence from samples_info. + samples_info = { + "total_iterations": 1, + "time": 0.0, + "log_evidence": log_likelihood, + } + samples_info.update(self._test_mode_samples_info()) samples = SamplesPDF( model=model, sample_list=sample_list, - samples_info={ - "total_iterations": 1, - "time": 0.0, - "log_evidence": log_likelihood, - }, + samples_info=samples_info, ) samples_summary = samples.summary() @@ -888,6 +890,19 @@ def _fit_bypass_test_mode( return result + def _test_mode_samples_info(self) -> dict: + """ + Sampler-specific keys to merge into ``samples_info`` when the + sampler is bypassed via ``PYAUTO_TEST_MODE=2`` or ``=3``. + + Override in subclasses to add the diagnostic keys that the real + run would populate (e.g. NUTS ESS, MCMC autocorrelations) so that + tutorial scripts and downstream code can access those keys + without ``KeyError``. Use NaN/0 placeholders — the bypass did not + actually sample. + """ + return {} + @staticmethod def _build_fake_samples(model, parameter_vector, log_likelihood): """ diff --git a/autofit/non_linear/search/mcmc/blackjax/nuts/search.py b/autofit/non_linear/search/mcmc/blackjax/nuts/search.py index 6bf9e9e57..986d99b26 100644 --- a/autofit/non_linear/search/mcmc/blackjax/nuts/search.py +++ b/autofit/non_linear/search/mcmc/blackjax/nuts/search.py @@ -378,6 +378,20 @@ def output_search_internal(self, search_internal): with open(self.backend_filename, "wb") as f: pickle.dump(search_internal, f) + def _test_mode_samples_info(self) -> dict: + return { + "num_warmup": int(self.num_warmup), + "num_samples": 0, + "num_chains": int(self.num_chains), + "ess_min": float("nan"), + "ess_per_param": [], + "mean_acceptance": float("nan"), + "n_divergent": 0, + "n_logl_evals": 0, + "total_walkers": int(self.num_chains), + "total_steps": 0, + } + def samples_info_from(self, search_internal=None): search_internal = search_internal if search_internal is not None else self.backend