diff --git a/autofit/non_linear/search/nest/nss/_chunked_update.py b/autofit/non_linear/search/nest/nss/_chunked_update.py new file mode 100644 index 000000000..5a226a82e --- /dev/null +++ b/autofit/non_linear/search/nest/nss/_chunked_update.py @@ -0,0 +1,119 @@ +"""Chunked replacement for ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``. + +Upstream blackjax fans out ``num_delete`` particles through ``jax.vmap`` +with no chunking: + + sample_keys = jax.random.split(sample_key, num_delete) + return jax.vmap(mcmc_kernel)(sample_keys, start_state) + +On inversion-heavy likelihoods (e.g. PyAutoLens pixelization / Delaunay +source models) the per-particle MCMC state plus scatter temp buffers +exceeds A100 80 GB even at ``num_delete=16``. See PyAutoFit#1301 for the +full diagnosis and per-cell evidence from ``autolens_profiling``. + +``chunked_update_with_mcmc_take_last`` accepts a ``chunk_size`` kwarg and +swaps the vmap for ``jax.lax.map(..., batch_size=chunk_size)`` when +``chunk_size < num_delete`` — same vmap parallelism within a chunk, sequential +chunks across. Peak memory becomes ``chunk_size × per_particle_state`` +instead of ``num_delete × per_particle_state``. + +When ``chunk_size`` is None or ``>= num_delete`` the function is +bit-identical to upstream. + +``blackjax.nss(...)`` already exposes ``update_strategy`` as a kwarg +(see ``blackjax/ns/nss.py:157``), so ``af.NSS._fit`` only needs to pass +this builder to opt in: + + algo = _blackjax.nss( + ..., + update_strategy=make_chunked_update_strategy(chunk_size), + ) +""" + +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional + + +def make_chunked_update_strategy(chunk_size: Optional[int]) -> Callable: + """Return an ``update_strategy`` callable for ``blackjax.nss(...)``. + + Signature matches ``blackjax.ns.from_mcmc.update_with_mcmc_take_last`` + so it can be passed through the ``update_strategy=`` kwarg unmodified. + + Parameters + ---------- + chunk_size + Number of particles to vmap-batch per chunk. When None or + ``>= num_delete`` the chunked path is skipped and the function + falls through to a plain ``jax.vmap`` (matching upstream + behaviour bit-for-bit). + """ + + def chunked_update_with_mcmc_take_last( + constrained_mcmc_step_fn, + num_mcmc_steps, + num_delete, + ): + """Drop-in for ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``. + + Identical to upstream except the inner + ``jax.vmap(mcmc_kernel)(sample_keys, start_state)`` is replaced + with ``jax.lax.map(..., batch_size=chunk_size)`` when + ``chunk_size`` is set and smaller than ``num_delete``. + """ + import jax + import jax.numpy as jnp + + def update_function(rng_key, state, loglikelihood_0, **step_parameters): + choice_key, sample_key = jax.random.split(rng_key) + particles = state.particles + + # Select start particles from survivors (verbatim from upstream). + weights = (particles.loglikelihood > loglikelihood_0).astype(jnp.float32) + weights = jnp.where(weights.sum() > 0.0, weights, jnp.ones_like(weights)) + start_idx = jax.random.choice( + choice_key, + len(weights), + shape=(num_delete,), + p=weights / weights.sum(), + replace=True, + ) + start_state = jax.tree.map(lambda x: x[start_idx], particles) + + shared_mcmc_step_fn = partial( + constrained_mcmc_step_fn, + loglikelihood_0=loglikelihood_0, + **step_parameters, + ) + + def mcmc_kernel(rng_key, state): + keys = jax.random.split(rng_key, num_mcmc_steps) + + def body_fn(state, rng_key): + new_state, info = shared_mcmc_step_fn(rng_key, state) + return new_state, info + + final_state, infos = jax.lax.scan(body_fn, state, keys) + return final_state, infos + + sample_keys = jax.random.split(sample_key, num_delete) + + # Fall through to bit-identical upstream behaviour when the + # user hasn't asked for chunking, or when the requested chunk + # already covers every particle. + if chunk_size is None or chunk_size >= num_delete: + return jax.vmap(mcmc_kernel)(sample_keys, start_state) + + # Chunked path: jax.lax.map(batch_size=k) vmaps within each + # chunk-of-k particles and loops across chunks. + return jax.lax.map( + lambda xs: mcmc_kernel(xs[0], xs[1]), + (sample_keys, start_state), + batch_size=chunk_size, + ) + + return update_function + + return chunked_update_with_mcmc_take_last diff --git a/autofit/non_linear/search/nest/nss/search.py b/autofit/non_linear/search/nest/nss/search.py index c91919ef5..384928b04 100644 --- a/autofit/non_linear/search/nest/nss/search.py +++ b/autofit/non_linear/search/nest/nss/search.py @@ -139,6 +139,7 @@ def __init__( n_live: int = 200, num_mcmc_steps: int = 5, num_delete: int = 50, + chunk_size: Optional[int] = None, termination: float = -3.0, checkpoint_interval: int = 100, iterations_per_quick_update: Optional[int] = None, @@ -195,6 +196,17 @@ def __init__( Number of particles killed per outer iteration. Larger ``num_delete`` reduces JIT overhead per iteration at the cost of slightly worse posterior coverage. + chunk_size + Optional GPU-memory knob. When set and ``< num_delete``, the + inner MCMC step vmap (which fans out ``num_delete`` particles + in parallel inside ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``) + is replaced with ``jax.lax.map(..., batch_size=chunk_size)``. + Peak GPU memory becomes ``chunk_size × per_particle_state`` + instead of ``num_delete × per_particle_state`` — required to + run NSS on inversion-heavy likelihoods (PyAutoLens pixelization + / Delaunay) at A100 80 GB scale. Default ``None`` preserves the + upstream un-chunked behaviour (and is the right choice on CPU + or whenever ``num_delete`` already fits the device). termination Convergence criterion. The fit stops when ``logZ_live - logZ < termination``. Default ``-3.0`` corresponds @@ -262,6 +274,7 @@ def __init__( self.n_live = n_live self.num_mcmc_steps = num_mcmc_steps self.num_delete = num_delete + self.chunk_size = chunk_size self.termination = termination self.checkpoint_interval = checkpoint_interval self.seed = seed @@ -342,12 +355,28 @@ def prior_logprob(params): ] ) - algo = _blackjax.nss( + nss_kwargs = dict( logprior_fn=prior_logprob, loglikelihood_fn=log_likelihood, num_delete=self.num_delete, num_inner_steps=self.num_mcmc_steps, ) + # When ``chunk_size`` is set and below ``num_delete``, swap blackjax's + # default ``update_with_mcmc_take_last`` for a chunked variant whose + # inner vmap becomes ``jax.lax.map(batch_size=chunk_size)`` — see + # ``_chunked_update.py`` for the rationale and PyAutoFit#1301 for + # per-A100-cell evidence. ``chunk_size=None`` / ``>= num_delete`` are + # no-ops (the chunked builder still falls back to ``jax.vmap``). + if self.chunk_size is not None and self.chunk_size < self.num_delete: + from autofit.non_linear.search.nest.nss._chunked_update import ( + make_chunked_update_strategy, + ) + + nss_kwargs["update_strategy"] = make_chunked_update_strategy( + self.chunk_size + ) + + algo = _blackjax.nss(**nss_kwargs) @jax.jit def one_step(carry, _): @@ -370,11 +399,12 @@ def one_step(carry, _): iteration = 0 self.logger.info( "NSS configuration: n_live=%d, num_mcmc_steps=%d, num_delete=%d, " - "termination=%s, ndim=%d, checkpoint_interval=%d. JIT compile on " - "first iteration may take 25-30 s.", + "chunk_size=%s, termination=%s, ndim=%d, checkpoint_interval=%d. " + "JIT compile on first iteration may take 25-30 s.", self.n_live, self.num_mcmc_steps, self.num_delete, + self.chunk_size, self.termination, ndim, self.checkpoint_interval, diff --git a/test_autofit/non_linear/search/nest/nss/test_search.py b/test_autofit/non_linear/search/nest/nss/test_search.py index d9492ece6..a81c78b69 100644 --- a/test_autofit/non_linear/search/nest/nss/test_search.py +++ b/test_autofit/non_linear/search/nest/nss/test_search.py @@ -24,6 +24,7 @@ def test__explicit_params(): n_live=500, num_mcmc_steps=10, num_delete=20, + chunk_size=4, termination=-2.0, seed=7, ) @@ -31,6 +32,7 @@ def test__explicit_params(): assert search.n_live == 500 assert search.num_mcmc_steps == 10 assert search.num_delete == 20 + assert search.chunk_size == 4 assert search.termination == -2.0 assert search.seed == 7 @@ -38,10 +40,36 @@ def test__explicit_params(): assert default.n_live == 200 assert default.num_mcmc_steps == 5 assert default.num_delete == 50 + assert default.chunk_size is None assert default.termination == -3.0 assert default.seed == 42 +def test__chunked_update_strategy_factory(): + """``make_chunked_update_strategy`` returns a callable with the same + signature as blackjax's ``update_with_mcmc_take_last`` regardless of + whether ``chunk_size`` is set. This lets ``af.NSS._fit`` drop it into + ``blackjax.nss(update_strategy=...)`` without further branching. + """ + from autofit.non_linear.search.nest.nss._chunked_update import ( + make_chunked_update_strategy, + ) + + strategy_none = make_chunked_update_strategy(None) + strategy_chunked = make_chunked_update_strategy(4) + # Both are callables with the upstream three-arg signature + # (constrained_mcmc_step_fn, num_mcmc_steps, num_delete). + import inspect + + for strategy in (strategy_none, strategy_chunked): + params = list(inspect.signature(strategy).parameters) + assert params == [ + "constrained_mcmc_step_fn", + "num_mcmc_steps", + "num_delete", + ] + + def test__identifier_fields(): search = af.NSS() for field in ("n_live", "num_mcmc_steps", "num_delete", "termination", "seed"):