From 4276277d4d453d81a9c5850417e3c27a06dd7317 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 29 May 2026 09:43:39 +0100 Subject: [PATCH] fix(nss): chunked algo.init follow-up to #1303 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyAutoFit#1303 chunked the per-iteration MCMC step's jax.vmap over the num_delete particles but left a separate hardcoded jax.vmap(init_state_fn) inside blackjax's nss.as_top_level_api init_fn unchunked. An A100 retry of the cells that #1303 was supposed to unblock (autolens_profiling NSS pixelization + delaunay × HST × fp64, jobs 322605 + 322606) still OOMs at the same byte counts as before #1303 (28.05 GB pix, 27.67 GB delaunay); the "NSS configuration:" log line never appears, confirming the crash is in algo.init, not algo.step. This PR ships the missing init-side chunking. New module autofit/non_linear/search/nest/nss/_chunked_nss.py exposes build_chunked_nss_algorithm — a ~30-line local replica of blackjax.ns.nss.as_top_level_api that controls both vmap sites: - step path: make_chunked_update_strategy from #1303 (unchanged) - init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size) instead of jax.vmap(init_state_fn) af.NSS._fit switches to the local builder when chunk_size < max(n_live, num_delete). chunk_size=None or chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...) bit-for-bit. 5D Gaussian smoke test at n_live=20 (5 init chunks of 4): bit-identical log_Z = -4.3208251152 between chunk_size=None and chunk_size=4. Both configuration log lines print the chunk_size value, confirming both code paths fired. 
Refs #1301, #1303, #1304 Co-Authored-By: Claude Opus 4.7 (1M context) --- .../search/nest/nss/_chunked_nss.py | 115 ++++++++++++++++++ autofit/non_linear/search/nest/nss/search.py | 48 +++++--- .../non_linear/search/nest/nss/test_search.py | 36 ++++++ 3 files changed, 180 insertions(+), 19 deletions(-) create mode 100644 autofit/non_linear/search/nest/nss/_chunked_nss.py diff --git a/autofit/non_linear/search/nest/nss/_chunked_nss.py b/autofit/non_linear/search/nest/nss/_chunked_nss.py new file mode 100644 index 000000000..d23ee7a34 --- /dev/null +++ b/autofit/non_linear/search/nest/nss/_chunked_nss.py @@ -0,0 +1,115 @@ +"""Local replica of ``blackjax.ns.nss.as_top_level_api`` with chunked init. + +PyAutoFit#1303 added ``make_chunked_update_strategy`` for the per-iteration +MCMC step path inside ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``. +That covers the inner vmap over ``num_delete`` particles, but **not** the +separate hardcoded ``jax.vmap(init_state_fn)`` inside +``blackjax.ns.nss.as_top_level_api``'s ``init_fn`` +(``blackjax/ns/nss.py:223-230`` in the handley-lab fork) — and that's where +inversion-heavy lensing cells (PyAutoLens pixelization / Delaunay at HST +scale) OOM on A100 80 GB before the sampling loop even starts. + +This module replaces ``_blackjax.nss(...)`` with a local builder so we +control both seams: + +- The chunked update_strategy goes through ``make_chunked_update_strategy`` + (PyAutoFit#1303). Behaviour unchanged from that PR. +- The chunked init swaps ``jax.vmap(init_state_fn)`` for + ``jax.lax.map(init_state_fn, positions, batch_size=chunk_size)`` so peak + GPU memory during ``algo.init`` becomes ``chunk_size × per_particle_state`` + instead of ``n_live × per_particle_state``. + +When ``chunk_size`` is None the builder still uses ``jax.vmap`` and is +bit-identical to upstream. 
The builder otherwise produces a +``blackjax.SamplingAlgorithm`` with the same shape ``_blackjax.nss(...)`` +returns, so ``af.NSS._fit`` is a one-line switch. + +See PyAutoFit#1304 for the diagnosis and A100 evidence (jobs 322605 / +322606 OOM at the same byte counts as before #1303 landed, because the +crash is in ``algo.init`` not ``algo.step``). +""" + +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional + + +def build_chunked_nss_algorithm( + *, + logprior_fn: Callable, + loglikelihood_fn: Callable, + num_inner_steps: int, + num_delete: int, + chunk_size: Optional[int], +): + """Return a ``blackjax.SamplingAlgorithm`` with chunked init + step paths. + + Replicates the body of ``blackjax.ns.nss.as_top_level_api`` (the + handley-lab fork) so we can plug in chunked variants of the two + vmap sites — the inner MCMC step (via the existing + ``make_chunked_update_strategy``) and the n_live-wide init. + + Parameters + ---------- + logprior_fn + Log-prior callable, ``positions -> scalar log-prior``. + loglikelihood_fn + Log-likelihood callable, ``positions -> scalar log-L``. + num_inner_steps + Number of HRSS steps per particle replacement (matches + ``af.NSS.num_mcmc_steps``). + num_delete + Number of particles replaced per outer iteration (matches + ``af.NSS.num_delete``). + chunk_size + Optional GPU-memory knob. When None, both vmap sites use plain + ``jax.vmap`` and the result is bit-identical to upstream + ``blackjax.nss(...)``. When set, peak memory in each site becomes + ``chunk_size × per_particle_state``. + """ + # Local imports keep this module cheap to import when ``af.NSS`` is + # never used (blackjax + jax are optional deps gated by the + # ``[nss]`` extra; see PyAutoFit pyproject.toml). 
+ import jax + from blackjax import SamplingAlgorithm + from blackjax.ns.adaptive import init as ns_init + from blackjax.ns.base import init_state_strategy + from blackjax.ns.nss import build_kernel, update_inner_kernel_params + + from autofit.non_linear.search.nest.nss._chunked_update import ( + make_chunked_update_strategy, + ) + + init_state_fn = partial( + init_state_strategy, + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + ) + + kernel = build_kernel( + init_state_fn, + num_inner_steps, + num_delete, + update_strategy=make_chunked_update_strategy(chunk_size), + ) + + def init_fn(position, rng_key=None): + # Mirror the upstream signature (rng_key is unused but accepted for + # API parity with ``blackjax.ns.nss.as_top_level_api``). + if chunk_size is None: + init_batcher = jax.vmap(init_state_fn) + else: + init_batcher = lambda p: jax.lax.map( + init_state_fn, p, batch_size=chunk_size + ) + return ns_init( + position, + init_state_fn=init_batcher, + update_inner_kernel_params_fn=update_inner_kernel_params, + ) + + def step_fn(rng_key, state): + return kernel(rng_key, state) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/autofit/non_linear/search/nest/nss/search.py b/autofit/non_linear/search/nest/nss/search.py index 384928b04..744d78ba7 100644 --- a/autofit/non_linear/search/nest/nss/search.py +++ b/autofit/non_linear/search/nest/nss/search.py @@ -355,28 +355,38 @@ def prior_logprob(params): ] ) - nss_kwargs = dict( - logprior_fn=prior_logprob, - loglikelihood_fn=log_likelihood, - num_delete=self.num_delete, - num_inner_steps=self.num_mcmc_steps, - ) - # When ``chunk_size`` is set and below ``num_delete``, swap blackjax's - # default ``update_with_mcmc_take_last`` for a chunked variant whose - # inner vmap becomes ``jax.lax.map(batch_size=chunk_size)`` — see - # ``_chunked_update.py`` for the rationale and PyAutoFit#1301 for - # per-A100-cell evidence. 
``chunk_size=None`` / ``>= num_delete`` are - # no-ops (the chunked builder still falls back to ``jax.vmap``). - if self.chunk_size is not None and self.chunk_size < self.num_delete: - from autofit.non_linear.search.nest.nss._chunked_update import ( - make_chunked_update_strategy, + # When ``chunk_size`` is set and below the wider of ``n_live`` / + # ``num_delete``, build the algorithm via the PyAutoFit-local + # ``build_chunked_nss_algorithm``. That replicates + # ``blackjax.ns.nss.as_top_level_api`` (~30 lines) with both vmap + # sites chunked: the inner MCMC step (matches PyAutoFit#1303's + # ``make_chunked_update_strategy``) **and** the n_live-wide + # ``algo.init`` (PyAutoFit#1304 — the OOM-causing site for + # inversion-heavy lensing cells before this PR). ``chunk_size=None`` + # or ``chunk_size >= max(n_live, num_delete)`` keeps using upstream + # ``blackjax.nss(...)`` bit-for-bit. + if ( + self.chunk_size is not None + and self.chunk_size < max(self.n_live, self.num_delete) + ): + from autofit.non_linear.search.nest.nss._chunked_nss import ( + build_chunked_nss_algorithm, ) - nss_kwargs["update_strategy"] = make_chunked_update_strategy( - self.chunk_size + algo = build_chunked_nss_algorithm( + logprior_fn=prior_logprob, + loglikelihood_fn=log_likelihood, + num_inner_steps=self.num_mcmc_steps, + num_delete=self.num_delete, + chunk_size=self.chunk_size, + ) + else: + algo = _blackjax.nss( + logprior_fn=prior_logprob, + loglikelihood_fn=log_likelihood, + num_delete=self.num_delete, + num_inner_steps=self.num_mcmc_steps, ) - - algo = _blackjax.nss(**nss_kwargs) @jax.jit def one_step(carry, _): diff --git a/test_autofit/non_linear/search/nest/nss/test_search.py b/test_autofit/non_linear/search/nest/nss/test_search.py index a81c78b69..da3d4ea42 100644 --- a/test_autofit/non_linear/search/nest/nss/test_search.py +++ b/test_autofit/non_linear/search/nest/nss/test_search.py @@ -70,6 +70,42 @@ def test__chunked_update_strategy_factory(): ] +def 
test__chunked_nss_algorithm_factory(): + """``build_chunked_nss_algorithm`` returns a ``blackjax.SamplingAlgorithm``- + shape NamedTuple with ``init`` and ``step`` attributes. This is what + ``af.NSS._fit`` relies on as a drop-in for ``_blackjax.nss(...)``. + + No JAX execution — library policy keeps JAX-traced tests in + ``autofit_workspace_test``. + """ + blackjax = pytest.importorskip("blackjax") + from autofit.non_linear.search.nest.nss._chunked_nss import ( + build_chunked_nss_algorithm, + ) + + algo = build_chunked_nss_algorithm( + logprior_fn=lambda p: 0.0, + loglikelihood_fn=lambda p: 0.0, + num_inner_steps=3, + num_delete=4, + chunk_size=2, + ) + + assert isinstance(algo, blackjax.SamplingAlgorithm) + assert callable(algo.init) + assert callable(algo.step) + + # chunk_size=None → still returns the same shape (unchunked path). + algo_unchunked = build_chunked_nss_algorithm( + logprior_fn=lambda p: 0.0, + loglikelihood_fn=lambda p: 0.0, + num_inner_steps=3, + num_delete=4, + chunk_size=None, + ) + assert isinstance(algo_unchunked, blackjax.SamplingAlgorithm) + + def test__identifier_fields(): search = af.NSS() for field in ("n_live", "num_mcmc_steps", "num_delete", "termination", "seed"):