fix(nss): chunked algo.init follow-up to #1303 #1305

Merged
Jammy2211 merged 1 commit into main from feature/nss-chunked-init
May 29, 2026

@Jammy2211
Collaborator

Summary

PyAutoFit#1303 chunked the num_delete-wide jax.vmap in the per-iteration MCMC step but left a separate hardcoded jax.vmap(init_state_fn) inside the init_fn of blackjax.ns.nss.as_top_level_api unchunked. An A100 retry of the cells that #1303 was supposed to unblock (autolens_profiling NSS pixelization + delaunay × HST × fp64, jobs 322605 + 322606) ran out of memory at the same byte counts as before #1303 (28.05 GB pixelization, 27.67 GB delaunay). The "NSS configuration:" log line never appears, confirming that the crash is in algo.init, not algo.step.

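This PR ships the missing init-side chunking. The new module autofit.non_linear.search.nest.nss._chunked_nss exposes build_chunked_nss_algorithm, a ~30-line local replica of blackjax.ns.nss.as_top_level_api that controls both vmap sites:

  • step path: make_chunked_update_strategy from #1303 (unchanged)
  • init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size) instead of jax.vmap(init_state_fn)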

af.NSS._fit switches to the local builder when chunk_size < max(n_live, num_delete). chunk_size=None or chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...) bit-for-bit.

Refs #1301, #1303, #1304

API Changes

No public-API changes. The existing af.NSS(chunk_size=...) kwarg (added by #1303) gains coverage of the init path that previously OOMed before reaching the configuration log line. New private module autofit.non_linear.search.nest.nss._chunked_nss is an implementation detail (_-prefixed); only af.NSS._fit imports it.

af.NSS._fit now chooses between the local builder and upstream blackjax.nss(...) based on the chunk_size threshold. Default chunk_size=None continues to call _blackjax.nss(...) bit-for-bit.

See full details below.

Test Plan

  • pytest test_autofit/non_linear/search/nest/nss/test_search.py — 11/11 pass; new test__chunked_nss_algorithm_factory confirms the factory returns a blackjax.SamplingAlgorithm with callable .init and .step.
  • JAX-traced smoke (5D Gaussian, n_live=20, num_delete=4 → init runs 5 chunks of 4 when chunked): chunk_size=None and chunk_size=4 produce bit-identical log_Z = -4.3208251152 on the same seed. Both chunk_size= values appear in the configuration log line, confirming both code paths fired.
  • A100 follow-up: resubmit searches/nss/imaging/{pixelization,delaunay} × hst × fp64 on autolens_profiling. chunk_size=16 is set automatically by build_nss (already merged in autolens_profiling#43). Expect completion within ~3× of Nautilus baselines: pixelization 46.5 ms/eval / 46 min (322603); delaunay 84.8 ms/eval / 45 min (322601).
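The "5 chunks of 4" figure in the smoke test follows from simple index arithmetic, sketched below. The helper init_chunks is hypothetical and is plain Python, not blackjax or JAX internals; how a non-divisible remainder is actually handled by the library is not stated in this PR, so the second call only shows one plausible partition.

```python
def init_chunks(n_live, chunk_size):
    """Return (start, stop) index pairs that cover range(n_live) in order."""
    return [
        (start, min(start + chunk_size, n_live))
        for start in range(0, n_live, chunk_size)
    ]


# Smoke-test values from the test plan: n_live=20, chunk_size=4.
smoke = init_chunks(20, 4)
print(len(smoke), smoke)

# Illustrative non-divisible case: the final chunk is shorter.
print(init_chunks(10, 4))
```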
Full API Changes (for automation & release notes)

Added

  • autofit.non_linear.search.nest.nss._chunked_nss.build_chunked_nss_algorithm(*, logprior_fn, loglikelihood_fn, num_inner_steps, num_delete, chunk_size) -> blackjax.SamplingAlgorithm — local replica of blackjax.ns.nss.as_top_level_api with chunked init AND step. Private module (_-prefixed); not part of the user-facing API.

Changed Behaviour

  • af.NSS._fit now uses build_chunked_nss_algorithm(...) instead of _blackjax.nss(...) when self.chunk_size is not None and self.chunk_size < max(self.n_live, self.num_delete). The threshold widens from < num_delete (feat(nss): chunk_size kwarg for inversion-heavy A100 likelihoods #1303) to < max(n_live, num_delete) so the n_live-wide init path also benefits. chunk_size=None and chunk_size >= max(n_live, num_delete) continue to call _blackjax.nss(...) bit-for-bit.
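The threshold change can be written out as a small predicate. This is a sketch with hypothetical function names, not the code in af.NSS._fit; the None check in the #1303 rule is an assumption, since the text above only gives the comparison. With the smoke-test values (n_live=20, num_delete=4, chunk_size=4) the old rule would not select the chunked builder, while the new rule does.

```python
from typing import Optional


def uses_chunked_builder(chunk_size: Optional[int], n_live: int, num_delete: int) -> bool:
    # Rule described in this PR: chunk whenever chunk_size is smaller than
    # the widest mapped axis (init is n_live wide, step is num_delete wide).
    return chunk_size is not None and chunk_size < max(n_live, num_delete)


def uses_chunked_builder_1303(chunk_size: Optional[int], num_delete: int) -> bool:
    # Rule attributed to #1303 (None check assumed): only the step width counts.
    return chunk_size is not None and chunk_size < num_delete


print(uses_chunked_builder_1303(4, num_delete=4))
print(uses_chunked_builder(4, n_live=20, num_delete=4))
print(uses_chunked_builder(None, n_live=20, num_delete=4))
print(uses_chunked_builder(20, n_live=20, num_delete=4))
```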

Migration

  • None. The user-facing af.NSS(chunk_size=...) API is unchanged; this PR only fixes the underlying execution.

Identifier hash

🤖 Generated with Claude Code

PyAutoFit#1303 chunked the per-iteration MCMC step's jax.vmap(num_delete)
but left a separate hardcoded jax.vmap(init_state_fn) inside blackjax's
nss.as_top_level_api init_fn unchunked. A100 retry of the cells #1303
was supposed to unblock (autolens_profiling NSS pixelization + delaunay
× HST × fp64, jobs 322605 + 322606) OOM at the same byte counts as
before #1303 (28.05 GB pix, 27.67 GB delaunay); the "NSS configuration:"
log line never appears, confirming the crash is in algo.init not
algo.step.

This PR ships the missing init-side chunking. New module
autofit/non_linear/search/nest/nss/_chunked_nss.py exposes
build_chunked_nss_algorithm — a ~30-line local replica of
blackjax.nss.as_top_level_api that controls both vmap sites:

- step path: make_chunked_update_strategy from #1303 (unchanged)
- init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size)
  instead of jax.vmap(init_state_fn)

af.NSS._fit switches to the local builder when
chunk_size < max(n_live, num_delete). chunk_size=None or
chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...)
bit-for-bit.

5D Gaussian smoke at n_live=20 (5 init chunks of 4): bit-identical
log_Z = -4.3208251152 between chunk_size=None and chunk_size=4. Both
configuration log lines print the chunk_size value, confirming both
code paths fired.

Refs #1301, #1303, #1304

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Jammy2211 added the pending-release label (PR queued for the next release build) on May 29, 2026
Jammy2211 merged commit 2986bbf into main on May 29, 2026
7 checks passed
Jammy2211 deleted the feature/nss-chunked-init branch on May 29, 2026 08:45