Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 101 additions & 3 deletions scripts/searches/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
Searches: MCMC
==============

This example illustrates how to use the MCMC ensemble sampler algorithms supported by **PyAutoFit**:
This example illustrates how to use the MCMC sampler algorithms supported by **PyAutoFit**:

- `Emcee`: The emcee ensemble sampler.
- `Zeus`: The zeus ensemble slice sampler.
- `BlackJAXNUTS`: BlackJAX's No-U-Turn Sampler — gradient-based MCMC requiring `use_jax=True`.

Relevant links:

- Emcee: https://emcee.readthedocs.io/en/stable/
- Zeus: https://zeus-mcmc.readthedocs.io/en/latest/
- Emcee: https://emcee.readthedocs.io/en/stable/
- Zeus: https://zeus-mcmc.readthedocs.io/en/latest/
- BlackJAX: https://github.com/blackjax-devs/blackjax

__Contents__

Expand All @@ -20,6 +22,7 @@
- **Model + Analysis**: Setting up the model and analysis shared by every search below.
- **Search: Emcee**: Configuring and running the Emcee sampler.
- **Search: Zeus**: Configuring and running the Zeus sampler.
- **Search: BlackJAXNUTS**: Configuring and running BlackJAX's NUTS sampler (requires `use_jax=True`).
- **Search Internal**: Accessing the internal sampler for advanced use (shown once for Emcee).
"""

Expand Down Expand Up @@ -224,3 +227,98 @@
plt.ylabel("Profile normalization")
plt.show()
plt.close()

"""
__Search: BlackJAXNUTS__

`BlackJAXNUTS` is the No-U-Turn Sampler from BlackJAX — a gradient-based MCMC that extends
Hamiltonian Monte Carlo by adapting trajectory length on the fly, so the user does not have to
hand-tune the leapfrog step count. Because it uses gradients, the run is typically much more
sample-efficient than the ensemble samplers above on smooth, unimodal posteriors.

Two requirements distinguish it from `Emcee` / `Zeus`:

1) The analysis must be built with `use_jax=True` so the log-likelihood is JAX-traceable end to
end (`jax.grad` of it has to be takeable). Below we construct a separate `analysis_jax` to
keep the existing `analysis` (used by Emcee / Zeus) on its NumPy path.
2) The user's model must be registered as a JAX pytree so `model.instance_from_vector` flows
through `jax.jit`. `enable_pytrees()` is a one-shot process-level call; `register_model(model)`
walks the user's model and registers each concrete class it finds (here, `af.ex.Gaussian`).
This is the same mechanism the `Nautilus_jax` example uses.

The fit itself runs in two phases:

1) `blackjax.window_adaptation` (warmup) — tunes the leapfrog step size via dual averaging and
a diagonal inverse mass matrix from the warmup covariance.
2) NUTS sampling with the tuned kernel, run inside a JIT-compiled `jax.lax.scan` so the inner
per-step kernel runs as a single fused XLA computation.

`blackjax` is an optional dependency — install with `pip install autofit[optional]` (which now
pulls it in alongside `nautilus-sampler` etc.) or directly with `pip install blackjax`.

Relevant links:

- BlackJAX: https://github.com/blackjax-devs/blackjax
- BlackJAX docs: https://blackjax-devs.github.io/blackjax/
- The No-U-Turn paper: https://arxiv.org/abs/1111.4246

If you use `BlackJAXNUTS` as part of a published work, please cite the BlackJAX package following
the instructions on its GitHub page.
"""
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
register_model(model)

analysis_jax = af.ex.Analysis(data=data, noise_map=noise_map, use_jax=True)

search = af.BlackJAXNUTS(
path_prefix="searches",
name="BlackJAXNUTS",
num_warmup=500,
num_samples=1000,
target_accept=0.8,
)

result = search.fit(model=model, analysis=analysis_jax)

model_data = result.max_log_likelihood_instance.model_data_from(
xvalues=np.arange(data.shape[0])
)

plt.errorbar(
x=range(data.shape[0]),
y=data,
yerr=noise_map,
linestyle="",
color="k",
ecolor="k",
elinewidth=1,
capsize=2,
)
plt.plot(range(data.shape[0]), model_data, color="r")
plt.title("BlackJAXNUTS model fit to 1D Gaussian dataset.")
plt.xlabel("x values of profile")
plt.ylabel("Profile normalization")
plt.show()
plt.close()

"""
The `samples_info` dict on `result.samples` exposes NUTS-specific diagnostics that don't apply to
the ensemble samplers above:

- `ess_min` / `ess_per_param`: effective sample size from BlackJAX's Geyer-style estimator.
Higher is better; values close to `num_samples` mean the chain is nearly independent.
- `mean_acceptance`: average Metropolis acceptance over the post-warmup chain. Should land
close to `target_accept` (0.8 by default) once warmup has converged.
- `n_divergent`: number of divergent transitions post-warmup. A non-zero count usually means
the step size is too aggressive or the posterior has narrow funnels — re-run with a
tighter `target_accept` (e.g. 0.95) if you see them.
- `n_logl_evals`: total leapfrog integration steps summed across the chain — the right
"cost-per-sample" denominator for comparing NUTS to ensemble methods.
"""
info = result.samples.samples_info
print(f"ESS (min over dims): {info['ess_min']:.1f} / {info['num_samples']}")
print(f"Mean acceptance: {info['mean_acceptance']:.3f} (target {0.8:.2f})")
print(f"Divergences: {info['n_divergent']}")
print(f"Total leapfrog evals: {info['n_logl_evals']}")
Loading