diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 2d7aa011..e2bb1e96 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -27,6 +27,9 @@ from autogalaxy.imaging.fit_imaging import FitImaging +_FIT_IMAGING_PYTREES_REGISTERED = False + + class AnalysisImaging(AnalysisDataset): Result = ResultImaging Visualizer = VisualizerImaging @@ -209,11 +212,38 @@ def _register_fit_imaging_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``galaxies``, ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_IMAGING_PYTREES_REGISTERED`` flag so repeated calls from each + ``fit_from`` invocation are cheap. ``DatasetModel`` may already be + registered by ``autofit.jax.pytrees.register_model`` (its + ``_REGISTERED_INSTANCE_CLASSES`` set is independent of autoarray's + ``_pytree_registered_classes``); cross-populate so + ``register_instance_pytree`` short-circuits instead of asking JAX to + register it twice. Mirrors the defense in + ``autogalaxy/ellipse/model/analysis.py``. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_IMAGING_PYTREES_REGISTERED + if _FIT_IMAGING_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + if DatasetModel in _af_registered: + _pytree_registered_classes.add(DatasetModel) + register_instance_pytree( FitImaging, no_flatten=("dataset", "adapt_images", "settings"), @@ -221,6 +251,8 @@ def _register_fit_imaging_pytrees() -> None: register_instance_pytree(DatasetModel) register_galaxies_pytree() + _FIT_IMAGING_PYTREES_REGISTERED = True + def save_attributes(self, paths: af.DirectoryPaths): """ Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files` diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index a59902cc..b67d6d90 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -30,6 +30,9 @@ logger.setLevel(level="INFO") +_FIT_INTERFEROMETER_PYTREES_REGISTERED = False + + class AnalysisInterferometer(AnalysisDataset): Result = ResultInterferometer Visualizer = VisualizerInterferometer @@ -173,11 +176,33 @@ def _register_fit_interferometer_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``galaxies`` and the autoarray wrappers it carries) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_INTERFEROMETER_PYTREES_REGISTERED`` flag. See + ``autogalaxy/imaging/model/analysis.py`` for the cross-registration + rationale. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_INTERFEROMETER_PYTREES_REGISTERED + if _FIT_INTERFEROMETER_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + if DatasetModel in _af_registered: + _pytree_registered_classes.add(DatasetModel) + register_instance_pytree( FitInterferometer, no_flatten=("dataset", "adapt_images", "settings"), @@ -185,6 +210,8 @@ def _register_fit_interferometer_pytrees() -> None: register_instance_pytree(DatasetModel) register_galaxies_pytree() + _FIT_INTERFEROMETER_PYTREES_REGISTERED = True + def save_attributes(self, paths: af.DirectoryPaths): """ Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder