diff --git a/scripts/searches/mcmc.py b/scripts/searches/mcmc.py index acf9501c..611a56ed 100644 --- a/scripts/searches/mcmc.py +++ b/scripts/searches/mcmc.py @@ -2,15 +2,17 @@ Searches: MCMC ============== -This example illustrates how to use the MCMC ensemble sampler algorithms supported by **PyAutoFit**: +This example illustrates how to use the MCMC sampler algorithms supported by **PyAutoFit**: - `Emcee`: The emcee ensemble sampler. - `Zeus`: The zeus ensemble slice sampler. + - `BlackJAXNUTS`: BlackJAX's No-U-Turn Sampler — gradient-based MCMC requiring `use_jax=True`. Relevant links: - - Emcee: https://emcee.readthedocs.io/en/stable/ - - Zeus: https://zeus-mcmc.readthedocs.io/en/latest/ + - Emcee: https://emcee.readthedocs.io/en/stable/ + - Zeus: https://zeus-mcmc.readthedocs.io/en/latest/ + - BlackJAX: https://github.com/blackjax-devs/blackjax __Contents__ @@ -20,6 +22,7 @@ - **Model + Analysis**: Setting up the model and analysis shared by every search below. - **Search: Emcee**: Configuring and running the Emcee sampler. - **Search: Zeus**: Configuring and running the Zeus sampler. +- **Search: BlackJAXNUTS**: Configuring and running BlackJAX's NUTS sampler (requires `use_jax=True`). - **Search Internal**: Accessing the internal sampler for advanced use (shown once for Emcee). """ @@ -224,3 +227,98 @@ plt.ylabel("Profile normalization") plt.show() plt.close() + +""" +__Search: BlackJAXNUTS__ + +`BlackJAXNUTS` is the No-U-Turn Sampler from BlackJAX — a gradient-based MCMC that extends +Hamiltonian Monte Carlo by adapting trajectory length on the fly, so the user does not have to +hand-tune the leapfrog step count. Because it uses gradients, the run is typically much more +sample-efficient than the ensemble samplers above on smooth, unimodal posteriors. + +Two requirements distinguish it from `Emcee` / `Zeus`: + + 1) The analysis must be built with `use_jax=True` so the log-likelihood is JAX-traceable end to + end (`jax.grad` must be able to differentiate it).
Below we construct a separate `analysis_jax` to + keep the existing `analysis` (used by Emcee / Zeus) on its NumPy path. + 2) The user's model must be registered as a JAX pytree so `model.instance_from_vector` can be + traced by `jax.jit`. `enable_pytrees()` is a one-shot process-level call; `register_model(model)` + walks the user's model and registers each concrete class it finds (here, `af.ex.Gaussian`). + This is the same mechanism the `Nautilus_jax` example uses. + +The fit itself runs in two phases: + + 1) `blackjax.window_adaptation` (warmup) — tunes the leapfrog step size via dual averaging and + estimates a diagonal inverse mass matrix from the warmup samples. + 2) NUTS sampling with the tuned kernel, run inside a JIT-compiled `jax.lax.scan` so the inner + per-step kernel runs as a single fused XLA computation. + +`blackjax` is an optional dependency — install with `pip install "autofit[optional]"` (which +pulls it in alongside `nautilus-sampler` etc.) or directly with `pip install blackjax`. + +Relevant links: + + - BlackJAX: https://github.com/blackjax-devs/blackjax + - BlackJAX docs: https://blackjax-devs.github.io/blackjax/ + - The No-U-Turn paper: https://arxiv.org/abs/1111.4246 + +If you use `BlackJAXNUTS` as part of a published work, please cite the BlackJAX package following +the instructions on its GitHub page.
+"""
+from autofit.jax.pytrees import enable_pytrees, register_model
+
+# Pytree support is switched on once for the process; the model's classes are then
+# registered so model instances can pass through JAX transformations.
+enable_pytrees()
+register_model(model)
+
+# A separate JAX-enabled analysis, so the `analysis` object used by Emcee / Zeus is untouched.
+analysis_jax = af.ex.Analysis(data=data, noise_map=noise_map, use_jax=True)
+
+# Defined once so the search configuration and the diagnostics printed at the end of this
+# section can never disagree about the target acceptance rate.
+target_accept = 0.8
+
+search = af.BlackJAXNUTS(
+    path_prefix="searches",
+    name="BlackJAXNUTS",
+    num_warmup=500,
+    num_samples=1000,
+    target_accept=target_accept,
+)
+
+result = search.fit(model=model, analysis=analysis_jax)
+
+model_data = result.max_log_likelihood_instance.model_data_from(
+    xvalues=np.arange(data.shape[0])
+)
+
+plt.errorbar(
+    x=range(data.shape[0]),
+    y=data,
+    yerr=noise_map,
+    linestyle="",
+    color="k",
+    ecolor="k",
+    elinewidth=1,
+    capsize=2,
+)
+plt.plot(range(data.shape[0]), model_data, color="r")
+plt.title("BlackJAXNUTS model fit to 1D Gaussian dataset.")
+plt.xlabel("x values of profile")
+plt.ylabel("Profile normalization")
+plt.show()
+plt.close()
+
+"""
+The `samples_info` dict on `result.samples` exposes NUTS-specific diagnostics that don't apply to
+the ensemble samplers above:
+
+ - `ess_min` / `ess_per_param`: effective sample size from BlackJAX's Geyer-style estimator.
+   Higher is better; values close to `num_samples` mean the chain is nearly independent.
+ - `mean_acceptance`: average Metropolis acceptance over the post-warmup chain. Should land
+   close to `target_accept` (0.8 in this example) once warmup has converged.
+ - `n_divergent`: number of divergent transitions post-warmup. A non-zero count usually means
+   the step size is too aggressive or the posterior has narrow funnels — re-run with a
+   higher `target_accept` (e.g. 0.95) if you see them.
+ - `n_logl_evals`: total leapfrog integration steps summed across the chain — the right
+   "cost-per-sample" denominator for comparing NUTS to ensemble methods.
+"""
+info = result.samples.samples_info
+print(f"ESS (min over dims): {info['ess_min']:.1f} / {info['num_samples']}")
+# Report the same target that was passed to the search above, not a separate literal.
+print(f"Mean acceptance: {info['mean_acceptance']:.3f} (target {target_accept:.2f})")
+print(f"Divergences: {info['n_divergent']}")
+print(f"Total leapfrog evals: {info['n_logl_evals']}")