diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 1a2f0e46..e2ea6fc3 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -141,6 +141,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging: The fit of the galaxies to the imaging dataset, which includes the log likelihood. """ + if self._use_jax: + self._register_fit_imaging_pytrees() + galaxies = self.galaxies_via_instance_from( instance=instance, ) @@ -158,6 +161,48 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging: xp=self._xp, ) + @staticmethod + def _register_fit_imaging_pytrees() -> None: + """Register every type reachable from a ``FitImaging`` return value + so ``jax.jit(fit_from)`` can flatten its output. + + ``dataset``, ``adapt_images`` and ``settings`` are constants per + analysis — ride as aux so JAX does not recurse into them. Everything + else (``galaxies``, ``dataset_model`` and the autoarray wrappers they + carry) is dynamic per fit. + """ + from autoarray.abstract_ndarray import ( + _pytree_registered_classes, + register_instance_pytree, + ) + from autoarray.dataset.dataset_model import DatasetModel + from autoconf.jax_wrapper import register_pytree_node + from autogalaxy.galaxy.galaxies import Galaxies + + register_instance_pytree( + FitImaging, + no_flatten=("dataset", "adapt_images", "settings"), + ) + register_instance_pytree(DatasetModel) + + # ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten + # in ``register_instance_pytree`` would drop the list contents. Register + # a custom flatten that carries the list items as dynamic children. + if Galaxies not in _pytree_registered_classes: + def _flatten_galaxies(galaxies): + dict_items = tuple(sorted(galaxies.__dict__.items())) + return tuple(galaxies), dict_items + + def _unflatten_galaxies(aux, children): + new = Galaxies.__new__(Galaxies) + list.__init__(new, children) + for key, value in aux: + setattr(new, key, value) + return new + + register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies) + _pytree_registered_classes.add(Galaxies) + def save_attributes(self, paths: af.DirectoryPaths): """ Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files` diff --git a/test_autogalaxy/imaging/model/test_analysis_imaging.py b/test_autogalaxy/imaging/model/test_analysis_imaging.py index d65b64a3..760b037d 100644 --- a/test_autogalaxy/imaging/model/test_analysis_imaging.py +++ b/test_autogalaxy/imaging/model/test_analysis_imaging.py @@ -37,3 +37,18 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles( fit = ag.FitImaging(dataset=masked_imaging_7x7, galaxies=galaxies) assert fit.log_likelihood == fit_figure_of_merit + + +def test__register_fit_imaging_pytrees__registers_fit_galaxies_and_dataset_model(): + from autoarray.abstract_ndarray import _pytree_registered_classes + from autoarray.dataset.dataset_model import DatasetModel + from autogalaxy.galaxy.galaxies import Galaxies + from autogalaxy.imaging.fit_imaging import FitImaging + + ag.AnalysisImaging._register_fit_imaging_pytrees() + + assert FitImaging in _pytree_registered_classes + assert DatasetModel in _pytree_registered_classes + assert Galaxies in _pytree_registered_classes + +