From e32ab5b7cc402b5615e75d26fc26e4171145c8f6 Mon Sep 17 00:00:00 2001
From: Jammy2211
Date: Thu, 30 Apr 2026 19:55:26 +0100
Subject: [PATCH] feat(dynesty): support JAX-jitted likelihoods via use_jax_jit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

dynesty 2.1.5's NestedSampler has no `vectorized` parameter — it calls
the likelihood one sample at a time, so Nautilus's vmap-batching
approach doesn't apply. Use `jax.jit` on its own instead: JAX's
compiled-function cache reuses the compiled likelihood across calls,
giving a fast CPU/GPU evaluation path for nested sampling without
requiring autodiff.

Changes:

- `Fitness.__init__` accepts `use_jax_jit: bool = False`. When set,
  `self._call = self._jit` (parallel to the existing `_vmap` dispatch).
  vmap takes precedence if both flags are somehow set.
- `Fitness.call_wrap` casts the jit-path return value to a Python
  `float`. dynesty's `logz` accumulators and HDF5 savestate require
  numpy/Python scalars, not raw JAX `Array`s. The vmap path is
  untouched — Nautilus accepts JAX arrays at its `vectorized=True`
  interface.
- `Fitness.__getstate__` / `__setstate__` re-enabled (previously
  commented out) and extended to strip `_call`, `_jit`, `_vmap`, `_grad`
  from the pickle. dynesty's `run_nested(checkpoint_file=...)` pickles
  the loglikelihood; JAX-compiled callables hold C++ XLA state that
  doesn't roundtrip through pickle. `__setstate__` re-derives the
  dispatch on resume so the cached_property recompiles lazily on the
  first call.
- `AbstractDynesty.__init__` accepts `use_jax_jit: bool = True`. In
  `_fit`, the upfront `Fitness(...)` construction passes
  `use_jax_jit=(analysis._use_jax and self.use_jax_jit)`. Default-on
  when JAX is enabled; user can disable via the search-class flag.
- The existing no-pool fallback (triggered when `force_x1_cpu` or
  `analysis._use_jax`) now branches the log message three ways: JAX
  path, force_x1_cpu, OS-multiprocessing fallback.
The original message wrongly attributed JAX/force_x1_cpu fallbacks to "OS does not support multiprocessing". Tests: 5 new unit tests in `test_fitness_jax_dispatch.py` cover the dispatch logic and pickle round-trip. They do not import jax (per project policy: library unit tests stay numpy-only). `test_dict` fixture updated for the new `use_jax_jit` arg on `DynestyStatic`. Verification: companion script `Dynesty_jax.py` will land in `autofit_workspace_test`; runs end-to-end with `log_Z ≈ -54, dlogz < 0.5` on the standard 1D Gaussian dataset. `Nautilus_jax.py` remains green (vmap path unchanged). Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/non_linear/fitness.py | 32 ++++++--- .../search/nest/dynesty/search/abstract.py | 26 +++++-- test_autofit/non_linear/test_dict.py | 1 + .../non_linear/test_fitness_jax_dispatch.py | 67 +++++++++++++++++++ 4 files changed, 110 insertions(+), 16 deletions(-) create mode 100644 test_autofit/non_linear/test_fitness_jax_dispatch.py diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index f113651fa..d270c7885 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -42,6 +42,7 @@ def __init__( convert_to_chi_squared: bool = False, store_history: bool = False, use_jax_vmap : bool = False, + use_jax_jit : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, background_quick_update: bool = False, @@ -118,11 +119,14 @@ def __init__( self.log_likelihood_history_list = [] self.use_jax_vmap = use_jax_vmap + self.use_jax_jit = use_jax_jit self._call = self.call if self.use_jax_vmap: self._call = self._vmap + elif self.use_jax_jit: + self._call = self._jit self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update @@ -235,6 +239,9 @@ def call_wrap(self, parameters): figure_of_merit = self._call(parameters) + if self.use_jax_jit: + figure_of_merit = float(figure_of_merit) + if self.convert_to_chi_squared: 
log_likelihood = -0.5 * figure_of_merit else: @@ -382,15 +389,22 @@ def __call__(self, parameters, *kwargs): """ return self.call_wrap(parameters) - # def __getstate__(self): - # state = self.__dict__.copy() - # # Remove non-pickleable attributes - # state.pop('_call', None) - # state.pop('_grad', None) - # return state - # - # def __setstate__(self, state): - # self.__dict__.update(state) + def __getstate__(self): + state = self.__dict__.copy() + # Strip JAX-compiled callables: jax.jit / jax.vmap / jax.grad return + # functions tied to C++ XLA state that can't roundtrip through pickle. + # cached_property values lazily recompile on first access after unpickle. + for attr in ("_call", "_jit", "_vmap", "_grad"): + state.pop(attr, None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._call = self.call + if getattr(self, "use_jax_vmap", False): + self._call = self._vmap + elif getattr(self, "use_jax_jit", False): + self._call = self._jit @cached_property def _vmap(self): diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index de8d27e5e..bc53febed 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -53,6 +53,7 @@ def __init__( number_of_cores: int = 1, silence: bool = False, force_x1_cpu: bool = False, + use_jax_jit: bool = True, session: Optional[sa.orm.Session] = None, **kwargs, ): @@ -117,6 +118,7 @@ def __init__( self.maxcall = maxcall self.force_x1_cpu = force_x1_cpu + self.use_jax_jit = use_jax_jit self.logger.debug(f"Creating {self.__class__.__name__} Search") @@ -179,6 +181,7 @@ def _fit( paths=self.paths, fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, + use_jax_jit=getattr(analysis, "_use_jax", False) and self.use_jax_jit, ) if not isinstance(self.paths, NullPaths): @@ -225,13 +228,22 @@ def _fit( except RuntimeError: if not checkpoint_exists: 
- self.logger.info( - """ - Your operating system does not support Python multiprocessing. - - A single CPU non-multiprocessing Dynesty run is being performed. - """ - ) + if getattr(analysis, "_use_jax", False): + self.logger.info( + "Running Dynesty with JAX-jitted likelihood (single CPU, no pool)." + ) + elif self.force_x1_cpu: + self.logger.info( + "Running Dynesty single-CPU per `force_x1_cpu=True` (no pool)." + ) + else: + self.logger.info( + """ + Your operating system does not support Python multiprocessing. + + A single CPU non-multiprocessing Dynesty run is being performed. + """ + ) search_internal = self.search_internal_from( model=model, diff --git a/test_autofit/non_linear/test_dict.py b/test_autofit/non_linear/test_dict.py index 5a8034685..04eeb044c 100644 --- a/test_autofit/non_linear/test_dict.py +++ b/test_autofit/non_linear/test_dict.py @@ -43,6 +43,7 @@ def make_dynesty_dict(): "slices": 5, "unique_tag": None, "update_interval": None, + "use_jax_jit": True, "walks": 5, }, "class_path": "autofit.non_linear.search.nest.dynesty.search.static.DynestyStatic", diff --git a/test_autofit/non_linear/test_fitness_jax_dispatch.py b/test_autofit/non_linear/test_fitness_jax_dispatch.py new file mode 100644 index 000000000..23243c10c --- /dev/null +++ b/test_autofit/non_linear/test_fitness_jax_dispatch.py @@ -0,0 +1,67 @@ +import pickle + +import numpy as np + +import autofit as af +from autofit.non_linear.fitness import Fitness + + +def _make_fitness(**kwargs): + model = af.Model(af.ex.Gaussian) + data = np.ones(20) + noise_map = np.ones(20) * 0.1 + analysis = af.ex.Analysis(data=data, noise_map=noise_map) + return Fitness(model=model, analysis=analysis, **kwargs) + + +def test_default_dispatch_is_call(): + fitness = _make_fitness() + # `self.call` produces a fresh bound method each access — compare the + # underlying function instead of the bound-method instance. 
+ assert fitness._call.__func__ is Fitness.call + assert fitness.use_jax_jit is False + assert fitness.use_jax_vmap is False + + +def test_jit_dispatch_sets_call_to_jit(): + fitness = _make_fitness(use_jax_jit=True) + assert fitness.use_jax_jit is True + assert fitness._call is fitness._jit + + +def test_vmap_takes_precedence_over_jit(): + fitness = _make_fitness(use_jax_jit=True, use_jax_vmap=True) + assert fitness._call is fitness._vmap + + +def test_pickle_strips_jax_cached_attrs(): + """ + Dynesty's checkpoint writes pickle the loglikelihood. JAX-compiled + callables (jax.jit / jax.vmap / jax.grad) carry C++ XLA state that + cannot roundtrip through pickle. ``Fitness.__getstate__`` must drop + them; ``Fitness.__setstate__`` re-derives the dispatch on resume. + """ + fitness = _make_fitness(use_jax_jit=True) + + state = fitness.__getstate__() + assert "_call" not in state + assert "_jit" not in state + assert "_vmap" not in state + assert "_grad" not in state + + blob = pickle.dumps(fitness) + restored = pickle.loads(blob) + + assert restored.use_jax_jit is True + assert restored._call is restored._jit + + +def test_pickle_default_path_unchanged(): + fitness = _make_fitness() + + blob = pickle.dumps(fitness) + restored = pickle.loads(blob) + + assert restored.use_jax_jit is False + assert restored.use_jax_vmap is False + assert restored._call.__func__ is Fitness.call