Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion autogalaxy/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from autogalaxy.imaging.fit_imaging import FitImaging


_FIT_IMAGING_PYTREES_REGISTERED = False


class AnalysisImaging(AnalysisDataset):
Result = ResultImaging
Visualizer = VisualizerImaging
Expand Down Expand Up @@ -209,18 +212,47 @@ def _register_fit_imaging_pytrees() -> None:
analysis — ride as aux so JAX does not recurse into them. Everything
else (``galaxies``, ``dataset_model`` and the autoarray wrappers they
carry) is dynamic per fit.

Idempotent — guarded by the module-level
``_FIT_IMAGING_PYTREES_REGISTERED`` flag so repeated calls from each
``fit_from`` invocation are cheap. ``DatasetModel`` may already be
registered by ``autofit.jax.pytrees.register_model`` (its
``_REGISTERED_INSTANCE_CLASSES`` set is independent of autoarray's
``_pytree_registered_classes``); cross-populate so
``register_instance_pytree`` short-circuits instead of asking JAX to
register it twice. Mirrors the defense in
``autogalaxy/ellipse/model/analysis.py``.
"""
from autoarray.abstract_ndarray import register_instance_pytree
global _FIT_IMAGING_PYTREES_REGISTERED
if _FIT_IMAGING_PYTREES_REGISTERED:
return

from autoarray.abstract_ndarray import (
register_instance_pytree,
_pytree_registered_classes,
)
from autoarray.dataset.dataset_model import DatasetModel
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

try:
from autofit.jax.pytrees import (
_REGISTERED_INSTANCE_CLASSES as _af_registered,
)
except ImportError:
_af_registered = set()

if DatasetModel in _af_registered:
_pytree_registered_classes.add(DatasetModel)

register_instance_pytree(
FitImaging,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)
register_galaxies_pytree()

_FIT_IMAGING_PYTREES_REGISTERED = True

def save_attributes(self, paths: af.DirectoryPaths):
"""
Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files`
Expand Down
29 changes: 28 additions & 1 deletion autogalaxy/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
logger.setLevel(level="INFO")


_FIT_INTERFEROMETER_PYTREES_REGISTERED = False


class AnalysisInterferometer(AnalysisDataset):
Result = ResultInterferometer
Visualizer = VisualizerInterferometer
Expand Down Expand Up @@ -173,18 +176,42 @@ def _register_fit_interferometer_pytrees() -> None:
analysis — ride as aux so JAX does not recurse into them. Everything
else (``galaxies`` and the autoarray wrappers it carries) is dynamic
per fit.

Idempotent — guarded by the module-level
``_FIT_INTERFEROMETER_PYTREES_REGISTERED`` flag. See
``autogalaxy/imaging/model/analysis.py`` for the cross-registration
rationale.
"""
from autoarray.abstract_ndarray import register_instance_pytree
global _FIT_INTERFEROMETER_PYTREES_REGISTERED
if _FIT_INTERFEROMETER_PYTREES_REGISTERED:
return

from autoarray.abstract_ndarray import (
register_instance_pytree,
_pytree_registered_classes,
)
from autoarray.dataset.dataset_model import DatasetModel
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

try:
from autofit.jax.pytrees import (
_REGISTERED_INSTANCE_CLASSES as _af_registered,
)
except ImportError:
_af_registered = set()

if DatasetModel in _af_registered:
_pytree_registered_classes.add(DatasetModel)

register_instance_pytree(
FitInterferometer,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(DatasetModel)
register_galaxies_pytree()

_FIT_INTERFEROMETER_PYTREES_REGISTERED = True

def save_attributes(self, paths: af.DirectoryPaths):
"""
Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder
Expand Down
Loading