fix(nss): chunked algo.init follow-up to #1303 #1305
Merged
PyAutoFit#1303 chunked the per-iteration MCMC step's jax.vmap(num_delete) but left a separate hardcoded jax.vmap(init_state_fn) inside blackjax's nss.as_top_level_api init_fn unchunked. The A100 retry of the cells #1303 was supposed to unblock (autolens_profiling NSS pixelization + delaunay × HST × fp64, jobs 322605 + 322606) still OOMs at the same byte counts as before #1303 (28.05 GB pix, 27.67 GB delaunay); the "NSS configuration:" log line never appears, confirming the crash is in algo.init, not algo.step.

This PR ships the missing init-side chunking. New module autofit/non_linear/search/nest/nss/_chunked_nss.py exposes build_chunked_nss_algorithm, a ~30-line local replica of blackjax.ns.nss.as_top_level_api that controls both vmap sites:

- step path: make_chunked_update_strategy from #1303 (unchanged)
- init path: jax.lax.map(init_state_fn, positions, batch_size=chunk_size) instead of jax.vmap(init_state_fn)

af.NSS._fit switches to the local builder when chunk_size < max(n_live, num_delete). chunk_size=None or chunk_size >= max(n_live, num_delete) keeps using upstream blackjax.nss(...) bit-for-bit.

5D Gaussian smoke at n_live=20 (5 init chunks of 4): bit-identical log_Z = -4.3208251152 between chunk_size=None and chunk_size=4. Both configuration log lines print the chunk_size value, confirming both code paths fired.

Refs #1301, #1303, #1304

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
PyAutoFit#1303 chunked the per-iteration MCMC step's `jax.vmap(num_delete)` but left a separate hardcoded `jax.vmap(init_state_fn)` inside `blackjax.ns.nss.as_top_level_api`'s `init_fn` unchunked. The A100 retry of the cells #1303 was supposed to unblock (`autolens_profiling` NSS pixelization + delaunay × HST × fp64, jobs 322605 + 322606) still OOMs at the same byte counts as before #1303 (28.05 GB pixelization, 27.67 GB delaunay). The `NSS configuration:` log line never appears, confirming the crash is in `algo.init`, not `algo.step`.

This PR ships the missing init-side chunking. New module `autofit.non_linear.search.nest.nss._chunked_nss` exposes `build_chunked_nss_algorithm`, a ~30-line local replica of `blackjax.ns.nss.as_top_level_api` that controls both vmap sites:

- step path: `make_chunked_update_strategy` from #1303 (feat(nss): chunk_size kwarg for inversion-heavy A100 likelihoods), unchanged
- init path: `jax.lax.map(init_state_fn, positions, batch_size=chunk_size)` instead of `jax.vmap(init_state_fn)`

`af.NSS._fit` switches to the local builder when `chunk_size < max(n_live, num_delete)`. `chunk_size=None` or `chunk_size >= max(n_live, num_delete)` keeps using upstream `blackjax.nss(...)` bit-for-bit.

Refs #1301, #1303, #1304
API Changes
No public-API changes. The existing `af.NSS(chunk_size=...)` kwarg (added by #1303) gains coverage of the init path that previously OOMed before reaching the configuration log line. New private module `autofit.non_linear.search.nest.nss._chunked_nss` is an implementation detail (`_`-prefixed); only `af.NSS._fit` imports it. `af.NSS._fit` now chooses between the local builder and upstream `blackjax.nss(...)` based on the `chunk_size` threshold. Default `chunk_size=None` continues to call `_blackjax.nss(...)` bit-for-bit.

See full details below.
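The dispatch rule can be sketched as a pair of predicates. The function names are hypothetical and the `num_delete=10` value is illustrative; only the two threshold expressions come from this PR's description, and the `None` handling of the #1303 rule is an assumption.

```python
from typing import Optional


def uses_local_builder(chunk_size: Optional[int], n_live: int, num_delete: int) -> bool:
    # Threshold stated in this PR: take the chunked path only when chunk_size
    # is smaller than the widest vmap (init is n_live wide, step is num_delete wide).
    return chunk_size is not None and chunk_size < max(n_live, num_delete)


def chunked_under_1303(chunk_size: Optional[int], num_delete: int) -> bool:
    # Narrower threshold attributed to #1303 (step path only); the None check is assumed.
    return chunk_size is not None and chunk_size < num_delete


# Illustrative values: chunk_size=16 leaves the step vmap (10 wide) whole,
# but would still split the init vmap (20 wide).
print(chunked_under_1303(16, num_delete=10))             # False
print(uses_local_builder(16, n_live=20, num_delete=10))  # True
print(uses_local_builder(None, n_live=20, num_delete=10))  # False: upstream path
print(uses_local_builder(20, n_live=20, num_delete=10))    # False: upstream path
```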
Test Plan
- `pytest test_autofit/non_linear/search/nest/nss/test_search.py`: 11/11 pass; new `test__chunked_nss_algorithm_factory` confirms the factory returns a `blackjax.SamplingAlgorithm` with callable `.init` and `.step`.
- 5D Gaussian smoke at `n_live=20` (5 init chunks of 4): `chunk_size=None` and `chunk_size=4` produce bit-identical `log_Z = -4.3208251152` on the same seed. Both `chunk_size=` values appear in the configuration log line, confirming both code paths fired.
- `searches/nss/imaging/{pixelization,delaunay}` × hst × fp64 on `autolens_profiling`. `chunk_size=16` is set automatically by `build_nss` (already merged in autolens_profiling#43). Expect completion within ~3× of Nautilus baselines: pixelization 46.5 ms/eval / 46 min (322603); delaunay 84.8 ms/eval / 45 min (322601).

Full API Changes (for automation & release notes)
Added
- `autofit.non_linear.search.nest.nss._chunked_nss.build_chunked_nss_algorithm(*, logprior_fn, loglikelihood_fn, num_inner_steps, num_delete, chunk_size) -> blackjax.SamplingAlgorithm`: local replica of `blackjax.ns.nss.as_top_level_api` with chunked init and step. Private module (`_`-prefixed); not part of the user-facing API.

Changed Behaviour
- `af.NSS._fit` now uses `build_chunked_nss_algorithm(...)` instead of `_blackjax.nss(...)` when `self.chunk_size is not None and self.chunk_size < max(self.n_live, self.num_delete)`. The threshold widens from `< num_delete` (#1303) to `< max(n_live, num_delete)` so the n_live-wide init path also benefits. `chunk_size=None` and `chunk_size >= max(n_live, num_delete)` continue to call `_blackjax.nss(...)` bit-for-bit.

Migration
`af.NSS(chunk_size=...)` API is unchanged; this PR only fixes the underlying execution.

Identifier hash
`chunk_size` remains absent from `__identifier_fields__` (see #1303 rationale): it's a memory-layout hint that does not affect the posterior, so runs with different `chunk_size` values share the same identifier and reuse cached results.

🤖 Generated with Claude Code