Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions autofit/non_linear/search/nest/nss/_chunked_nss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Local replica of ``blackjax.ns.nss.as_top_level_api`` with chunked init.

PyAutoFit#1303 added ``make_chunked_update_strategy`` for the per-iteration
MCMC step path inside ``blackjax.ns.from_mcmc.update_with_mcmc_take_last``.
That covers the inner vmap over ``num_delete`` particles, but **not** the
separate hardcoded ``jax.vmap(init_state_fn)`` inside
``blackjax.ns.nss.as_top_level_api``'s ``init_fn``
(``blackjax/ns/nss.py:223-230`` in the handley-lab fork) — and that's where
inversion-heavy lensing cells (PyAutoLens pixelization / Delaunay at HST
scale) OOM on A100 80 GB before the sampling loop even starts.

This module replaces ``_blackjax.nss(...)`` with a local builder so we
control both seams:

- The chunked update_strategy goes through ``make_chunked_update_strategy``
(PyAutoFit#1303). Behaviour unchanged from that PR.
- The chunked init swaps ``jax.vmap(init_state_fn)`` for
``jax.lax.map(init_state_fn, positions, batch_size=chunk_size)`` so peak
GPU memory during ``algo.init`` becomes ``chunk_size × per_particle_state``
instead of ``n_live × per_particle_state``.

When ``chunk_size`` is None the builder still uses ``jax.vmap`` and is
bit-identical to upstream. The builder otherwise produces a
``blackjax.SamplingAlgorithm`` with the same shape ``_blackjax.nss(...)``
returns, so ``af.NSS._fit`` is a one-line switch.

See PyAutoFit#1304 for the diagnosis and A100 evidence (jobs 322605 /
322606 OOM at the same byte counts as before #1303 landed, because the
crash is in ``algo.init`` not ``algo.step``).
"""

from __future__ import annotations

from functools import partial
from typing import Callable, Optional


def build_chunked_nss_algorithm(
*,
logprior_fn: Callable,
loglikelihood_fn: Callable,
num_inner_steps: int,
num_delete: int,
chunk_size: Optional[int],
):
"""Return a ``blackjax.SamplingAlgorithm`` with chunked init + step paths.

Replicates the body of ``blackjax.ns.nss.as_top_level_api`` (the
handley-lab fork) so we can plug in chunked variants of the two
vmap sites — the inner MCMC step (via the existing
``make_chunked_update_strategy``) and the n_live-wide init.

Parameters
----------
logprior_fn
Log-prior callable, ``positions -> scalar log-prior``.
loglikelihood_fn
Log-likelihood callable, ``positions -> scalar log-L``.
num_inner_steps
Number of HRSS steps per particle replacement (matches
``af.NSS.num_mcmc_steps``).
num_delete
Number of particles replaced per outer iteration (matches
``af.NSS.num_delete``).
chunk_size
Optional GPU-memory knob. When None, both vmap sites use plain
``jax.vmap`` and the result is bit-identical to upstream
``blackjax.nss(...)``. When set, peak memory in each site becomes
``chunk_size × per_particle_state``.
"""
# Local imports keep this module cheap to import when ``af.NSS`` is
# never used (blackjax + jax are optional deps gated by the
# ``[nss]`` extra; see PyAutoFit pyproject.toml).
import jax
from blackjax import SamplingAlgorithm
from blackjax.ns.adaptive import init as ns_init
from blackjax.ns.base import init_state_strategy
from blackjax.ns.nss import build_kernel, update_inner_kernel_params

from autofit.non_linear.search.nest.nss._chunked_update import (
make_chunked_update_strategy,
)

init_state_fn = partial(
init_state_strategy,
logprior_fn=logprior_fn,
loglikelihood_fn=loglikelihood_fn,
)

kernel = build_kernel(
init_state_fn,
num_inner_steps,
num_delete,
update_strategy=make_chunked_update_strategy(chunk_size),
)

def init_fn(position, rng_key=None):
# Mirror the upstream signature (rng_key is unused but accepted for
# API parity with ``blackjax.ns.nss.as_top_level_api``).
if chunk_size is None:
init_batcher = jax.vmap(init_state_fn)
else:
init_batcher = lambda p: jax.lax.map(
init_state_fn, p, batch_size=chunk_size
)
return ns_init(
position,
init_state_fn=init_batcher,
update_inner_kernel_params_fn=update_inner_kernel_params,
)

def step_fn(rng_key, state):
return kernel(rng_key, state)

return SamplingAlgorithm(init_fn, step_fn)
48 changes: 29 additions & 19 deletions autofit/non_linear/search/nest/nss/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,28 +355,38 @@ def prior_logprob(params):
]
)

nss_kwargs = dict(
logprior_fn=prior_logprob,
loglikelihood_fn=log_likelihood,
num_delete=self.num_delete,
num_inner_steps=self.num_mcmc_steps,
)
# When ``chunk_size`` is set and below ``num_delete``, swap blackjax's
# default ``update_with_mcmc_take_last`` for a chunked variant whose
# inner vmap becomes ``jax.lax.map(batch_size=chunk_size)`` — see
# ``_chunked_update.py`` for the rationale and PyAutoFit#1301 for
# per-A100-cell evidence. ``chunk_size=None`` / ``>= num_delete`` are
# no-ops (the chunked builder still falls back to ``jax.vmap``).
if self.chunk_size is not None and self.chunk_size < self.num_delete:
from autofit.non_linear.search.nest.nss._chunked_update import (
make_chunked_update_strategy,
# When ``chunk_size`` is set and below the wider of ``n_live`` /
# ``num_delete``, build the algorithm via the PyAutoFit-local
# ``build_chunked_nss_algorithm``. That replicates
# ``blackjax.ns.nss.as_top_level_api`` (~30 lines) with both vmap
# sites chunked: the inner MCMC step (matches PyAutoFit#1303's
# ``make_chunked_update_strategy``) **and** the n_live-wide
# ``algo.init`` (PyAutoFit#1304 — the OOM-causing site for
# inversion-heavy lensing cells before this PR). ``chunk_size=None``
# or ``chunk_size >= max(n_live, num_delete)`` keeps using upstream
# ``blackjax.nss(...)`` bit-for-bit.
if (
self.chunk_size is not None
and self.chunk_size < max(self.n_live, self.num_delete)
):
from autofit.non_linear.search.nest.nss._chunked_nss import (
build_chunked_nss_algorithm,
)

nss_kwargs["update_strategy"] = make_chunked_update_strategy(
self.chunk_size
algo = build_chunked_nss_algorithm(
logprior_fn=prior_logprob,
loglikelihood_fn=log_likelihood,
num_inner_steps=self.num_mcmc_steps,
num_delete=self.num_delete,
chunk_size=self.chunk_size,
)
else:
algo = _blackjax.nss(
logprior_fn=prior_logprob,
loglikelihood_fn=log_likelihood,
num_delete=self.num_delete,
num_inner_steps=self.num_mcmc_steps,
)

algo = _blackjax.nss(**nss_kwargs)

@jax.jit
def one_step(carry, _):
Expand Down
36 changes: 36 additions & 0 deletions test_autofit/non_linear/search/nest/nss/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,42 @@ def test__chunked_update_strategy_factory():
]


def test__chunked_nss_algorithm_factory():
"""``build_chunked_nss_algorithm`` returns a ``blackjax.SamplingAlgorithm``-
shape NamedTuple with ``init`` and ``step`` attributes. This is what
``af.NSS._fit`` relies on as a drop-in for ``_blackjax.nss(...)``.

No JAX execution — library policy keeps JAX-traced tests in
``autofit_workspace_test``.
"""
blackjax = pytest.importorskip("blackjax")
from autofit.non_linear.search.nest.nss._chunked_nss import (
build_chunked_nss_algorithm,
)

algo = build_chunked_nss_algorithm(
logprior_fn=lambda p: 0.0,
loglikelihood_fn=lambda p: 0.0,
num_inner_steps=3,
num_delete=4,
chunk_size=2,
)

assert isinstance(algo, blackjax.SamplingAlgorithm)
assert callable(algo.init)
assert callable(algo.step)

# chunk_size=None → still returns the same shape (unchunked path).
algo_unchunked = build_chunked_nss_algorithm(
logprior_fn=lambda p: 0.0,
loglikelihood_fn=lambda p: 0.0,
num_inner_steps=3,
num_delete=4,
chunk_size=None,
)
assert isinstance(algo_unchunked, blackjax.SamplingAlgorithm)


def test__identifier_fields():
search = af.NSS()
for field in ("n_live", "num_mcmc_steps", "num_delete", "termination", "seed"):
Expand Down
Loading