diff --git a/autolens/imaging/model/analysis.py b/autolens/imaging/model/analysis.py index d9ecda543..dcf249444 100644 --- a/autolens/imaging/model/analysis.py +++ b/autolens/imaging/model/analysis.py @@ -27,6 +27,9 @@ logger.setLevel(level="INFO") +_FIT_IMAGING_PYTREES_REGISTERED = False + + class AnalysisImaging(AnalysisDataset): Result = ResultImaging @@ -174,11 +177,38 @@ def _register_fit_imaging_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``tracer``, ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_IMAGING_PYTREES_REGISTERED`` flag. ``DatasetModel`` and + ``Tracer`` may already be registered by + ``autofit.jax.pytrees.register_model`` (its + ``_REGISTERED_INSTANCE_CLASSES`` set is independent of autoarray's + ``_pytree_registered_classes``); cross-populate so + ``register_instance_pytree`` short-circuits. Mirrors the defense in + ``autogalaxy/ellipse/model/analysis.py``. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_IMAGING_PYTREES_REGISTERED + if _FIT_IMAGING_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel from autolens.lens.tracer import Tracer + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + for cls in (DatasetModel, Tracer): + if cls in _af_registered: + _pytree_registered_classes.add(cls) + register_instance_pytree( FitImaging, no_flatten=("dataset", "adapt_images", "settings"), @@ -187,3 +217,5 @@ def _register_fit_imaging_pytrees() -> None: # ``cosmology`` is a fixed physical constant per fit; ride as aux. register_instance_pytree(Tracer, no_flatten=("cosmology",)) + _FIT_IMAGING_PYTREES_REGISTERED = True + diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index 1fae1afae..d2b8bb93a 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -33,6 +33,9 @@ logger.setLevel(level="INFO") +_FIT_INTERFEROMETER_PYTREES_REGISTERED = False + + class AnalysisInterferometer(AnalysisDataset): Result = ResultInterferometer Visualizer = VisualizerInterferometer @@ -203,11 +206,34 @@ def _register_fit_interferometer_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``tracer`` and the autoarray wrappers it carries) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_INTERFEROMETER_PYTREES_REGISTERED`` flag. See + ``autolens/imaging/model/analysis.py`` for the cross-registration + rationale. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_INTERFEROMETER_PYTREES_REGISTERED + if _FIT_INTERFEROMETER_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel # fit-interferometer-pytree-mge from autolens.lens.tracer import Tracer + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + for cls in (DatasetModel, Tracer): + if cls in _af_registered: + _pytree_registered_classes.add(cls) + register_instance_pytree( FitInterferometer, no_flatten=("dataset", "adapt_images", "settings"), @@ -215,6 +241,8 @@ def _register_fit_interferometer_pytrees() -> None: register_instance_pytree(Tracer, no_flatten=("cosmology",)) register_instance_pytree(DatasetModel) # fit-interferometer-pytree-mge + _FIT_INTERFEROMETER_PYTREES_REGISTERED = True + def save_attributes(self, paths: af.DirectoryPaths): """ Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder