Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions autogalaxy/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
The fit of the galaxies to the imaging dataset, which includes the log likelihood.
"""

if self._use_jax:
self._register_fit_imaging_pytrees()

galaxies = self.galaxies_via_instance_from(
instance=instance,
)
Expand All @@ -158,6 +161,48 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
xp=self._xp,
)

@staticmethod
def _register_fit_imaging_pytrees() -> None:
"""Register every type reachable from a ``FitImaging`` return value
so ``jax.jit(fit_from)`` can flatten its output.

``dataset``, ``adapt_images`` and ``settings`` are constants per
analysis — ride as aux so JAX does not recurse into them. Everything
else (``galaxies``, ``dataset_model`` and the autoarray wrappers they
carry) is dynamic per fit.
"""
from autoarray.abstract_ndarray import (
_pytree_registered_classes,
register_instance_pytree,
)
from autoarray.dataset.dataset_model import DatasetModel
from autoconf.jax_wrapper import register_pytree_node
from autogalaxy.galaxy.galaxies import Galaxies

register_instance_pytree(
FitImaging,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)

# ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten
# in ``register_instance_pytree`` would drop the list contents. Register
# a custom flatten that carries the list items as dynamic children.
if Galaxies not in _pytree_registered_classes:
def _flatten_galaxies(galaxies):
dict_items = tuple(sorted(galaxies.__dict__.items()))
return tuple(galaxies), dict_items

def _unflatten_galaxies(aux, children):
new = Galaxies.__new__(Galaxies)
list.__init__(new, children)
for key, value in aux:
setattr(new, key, value)
return new

register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies)
_pytree_registered_classes.add(Galaxies)

def save_attributes(self, paths: af.DirectoryPaths):
"""
Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files`
Expand Down
15 changes: 15 additions & 0 deletions test_autogalaxy/imaging/model/test_analysis_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,18 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles(
fit = ag.FitImaging(dataset=masked_imaging_7x7, galaxies=galaxies)

assert fit.log_likelihood == fit_figure_of_merit


def test__register_fit_imaging_pytrees__registers_fit_galaxies_and_dataset_model():
from autoarray.abstract_ndarray import _pytree_registered_classes
from autoarray.dataset.dataset_model import DatasetModel
from autogalaxy.galaxy.galaxies import Galaxies
from autogalaxy.imaging.fit_imaging import FitImaging

ag.AnalysisImaging._register_fit_imaging_pytrees()

assert FitImaging in _pytree_registered_classes
assert DatasetModel in _pytree_registered_classes
assert Galaxies in _pytree_registered_classes


Loading