From 49910331c385364356be7b6d3ccc9d58b3417cff Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 22 Apr 2026 11:21:08 +0100 Subject: [PATCH 1/2] feat: register FitImaging, DatasetModel, Galaxies as pytrees in AnalysisImaging Adds `_register_fit_imaging_pytrees` staticmethod to `AnalysisImaging`, called from `fit_from` under the existing `use_jax` gate. Mirrors autolens's `AnalysisImaging._register_fit_imaging_pytrees`. Unblocks `jax.jit(analysis.fit_from)(instance)` returning a `FitImaging` with `jax.Array` leaves on the autogalaxy imaging path. Part of PyAutoLabs/autogalaxy_workspace_test#8 (epic #5, task 3/9). Co-Authored-By: Claude Opus 4.7 --- autogalaxy/imaging/model/analysis.py | 24 +++++++++++++++++++ .../imaging/model/test_analysis_imaging.py | 13 ++++++++++ 2 files changed, 37 insertions(+) diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 1a2f0e46d..3743d71cf 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -141,6 +141,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging: The fit of the galaxies to the imaging dataset, which includes the log likelihood. """ + if self._use_jax: + self._register_fit_imaging_pytrees() + galaxies = self.galaxies_via_instance_from( instance=instance, ) @@ -158,6 +161,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging: xp=self._xp, ) + @staticmethod + def _register_fit_imaging_pytrees() -> None: + """Register every type reachable from a ``FitImaging`` return value + so ``jax.jit(fit_from)`` can flatten its output. + + ``dataset``, ``adapt_images`` and ``settings`` are constants per + analysis — ride as aux so JAX does not recurse into them. Everything + else (``galaxies``, ``dataset_model`` and the autoarray wrappers they + carry) is dynamic per fit. + """ + from autoarray.abstract_ndarray import register_instance_pytree + from autoarray.dataset.dataset_model import DatasetModel + from autogalaxy.galaxy.galaxies import Galaxies + + register_instance_pytree( + FitImaging, + no_flatten=("dataset", "adapt_images", "settings"), + ) + register_instance_pytree(DatasetModel) + register_instance_pytree(Galaxies) + def save_attributes(self, paths: af.DirectoryPaths): """ Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files` diff --git a/test_autogalaxy/imaging/model/test_analysis_imaging.py b/test_autogalaxy/imaging/model/test_analysis_imaging.py index d65b64a37..9a6823bdf 100644 --- a/test_autogalaxy/imaging/model/test_analysis_imaging.py +++ b/test_autogalaxy/imaging/model/test_analysis_imaging.py @@ -37,3 +37,16 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles( fit = ag.FitImaging(dataset=masked_imaging_7x7, galaxies=galaxies) assert fit.log_likelihood == fit_figure_of_merit + + +def test__register_fit_imaging_pytrees__registers_fit_galaxies_and_dataset_model(): + from autoarray.abstract_ndarray import _pytree_registered_classes + from autoarray.dataset.dataset_model import DatasetModel + from autogalaxy.galaxy.galaxies import Galaxies + from autogalaxy.imaging.fit_imaging import FitImaging + + ag.AnalysisImaging._register_fit_imaging_pytrees() + + assert FitImaging in _pytree_registered_classes + assert DatasetModel in _pytree_registered_classes + assert Galaxies in _pytree_registered_classes From 28b248b8fbd74416762027497c5ea75d1a2b205b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 22 Apr 2026 11:52:29 +0100 Subject: [PATCH 2/2] fix: register Galaxies list subclass with a custom pytree flatten `Galaxies` inherits from `list`, so the generic `__dict__`-based flatten in `register_instance_pytree` dropped the list contents across the JIT boundary. Register `Galaxies` with a custom flatten/unflatten that carries list items as dynamic children and any `__dict__` state as aux. Without this, `jax.jit(analysis.fit_from)(instance)` round-tripped but the resulting `FitImaging.galaxies` was an empty `Galaxies`, causing `galaxies.image_2d_from` to return `sum([])` (int 0) and fail in autoarray's `Array2D` constructor on `fit.log_likelihood` access. Co-Authored-By: Claude Opus 4.7 --- autogalaxy/imaging/model/analysis.py | 25 +++++++++++++++++-- .../imaging/model/test_analysis_imaging.py | 2 ++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 3743d71cf..e2ea6fc36 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -171,8 +171,12 @@ def _register_fit_imaging_pytrees() -> None: else (``galaxies``, ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit. """ - from autoarray.abstract_ndarray import register_instance_pytree + from autoarray.abstract_ndarray import ( + _pytree_registered_classes, + register_instance_pytree, + ) from autoarray.dataset.dataset_model import DatasetModel + from autoconf.jax_wrapper import register_pytree_node from autogalaxy.galaxy.galaxies import Galaxies register_instance_pytree( @@ -180,7 +184,24 @@ def _register_fit_imaging_pytrees() -> None: no_flatten=("dataset", "adapt_images", "settings"), ) register_instance_pytree(DatasetModel) - register_instance_pytree(Galaxies) + + # ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten + # in ``register_instance_pytree`` would drop the list contents. Register + # a custom flatten that carries the list items as dynamic children. + if Galaxies not in _pytree_registered_classes: + def _flatten_galaxies(galaxies): + dict_items = tuple(sorted(galaxies.__dict__.items())) + return tuple(galaxies), dict_items + + def _unflatten_galaxies(aux, children): + new = Galaxies.__new__(Galaxies) + list.__init__(new, children) + for key, value in aux: + setattr(new, key, value) + return new + + register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies) + _pytree_registered_classes.add(Galaxies) def save_attributes(self, paths: af.DirectoryPaths): """ diff --git a/test_autogalaxy/imaging/model/test_analysis_imaging.py b/test_autogalaxy/imaging/model/test_analysis_imaging.py index 9a6823bdf..760b037d9 100644 --- a/test_autogalaxy/imaging/model/test_analysis_imaging.py +++ b/test_autogalaxy/imaging/model/test_analysis_imaging.py @@ -50,3 +50,5 @@ def test__register_fit_imaging_pytrees__registers_fit_galaxies_and_dataset_model assert FitImaging in _pytree_registered_classes assert DatasetModel in _pytree_registered_classes assert Galaxies in _pytree_registered_classes + +