From 69b1e0992cd6729b84c2c91632037de469c2df0f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 29 May 2026 10:36:57 +0100 Subject: [PATCH] fix(jax): defensive pytree dedup in imaging/interferometer analyses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `AnalysisImaging._register_fit_imaging_pytrees` and the interferometer counterpart re-registered ``DatasetModel`` on every ``fit_from`` call even though autoarray's ``_pytree_registered_classes`` is supposed to make ``register_instance_pytree`` idempotent. The breakdown surfaced in the 2026.5.29.2 release on `multi/start_here.py`: ValueError: Duplicate custom PyTreeDef type registration for . Root cause: ``autofit.jax.pytrees.register_model`` walks the model and calls ``register_pytree_node`` for any class it finds (including ``DatasetModel`` when present in the model). It dedupes via its own ``_REGISTERED_INSTANCE_CLASSES`` set, **independent** of autoarray's ``_pytree_registered_classes``. Subsequent ``register_instance_pytree(DatasetModel)`` doesn't find it in autoarray's set, asks JAX to register again, JAX rejects. Mirrors the existing defense in ``autogalaxy/ellipse/model/analysis.py``: * Module-level ``_FIT_*_PYTREES_REGISTERED`` flag — skip re-entry. * Cross-populate autoarray's set from autofit's so ``register_instance_pytree`` short-circuits on classes autofit already handled. Co-Authored-By: Claude Opus 4.7 --- autogalaxy/imaging/model/analysis.py | 34 ++++++++++++++++++++- autogalaxy/interferometer/model/analysis.py | 29 +++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 2d7aa011..e2bb1e96 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -27,6 +27,9 @@ from autogalaxy.imaging.fit_imaging import FitImaging +_FIT_IMAGING_PYTREES_REGISTERED = False + + class AnalysisImaging(AnalysisDataset): Result = ResultImaging Visualizer = VisualizerImaging @@ -209,11 +212,38 @@ def _register_fit_imaging_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``galaxies``, ``dataset_model`` and the autoarray wrappers they carry) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_IMAGING_PYTREES_REGISTERED`` flag so repeated calls from each + ``fit_from`` invocation are cheap. ``DatasetModel`` may already be + registered by ``autofit.jax.pytrees.register_model`` (its + ``_REGISTERED_INSTANCE_CLASSES`` set is independent of autoarray's + ``_pytree_registered_classes``); cross-populate so + ``register_instance_pytree`` short-circuits instead of asking JAX to + register it twice. Mirrors the defense in + ``autogalaxy/ellipse/model/analysis.py``. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_IMAGING_PYTREES_REGISTERED + if _FIT_IMAGING_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + if DatasetModel in _af_registered: + _pytree_registered_classes.add(DatasetModel) + register_instance_pytree( FitImaging, no_flatten=("dataset", "adapt_images", "settings"), @@ -221,6 +251,8 @@ def _register_fit_imaging_pytrees() -> None: register_instance_pytree(DatasetModel) register_galaxies_pytree() + _FIT_IMAGING_PYTREES_REGISTERED = True + def save_attributes(self, paths: af.DirectoryPaths): """ Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files` diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index a59902cc..b67d6d90 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -30,6 +30,9 @@ logger.setLevel(level="INFO") +_FIT_INTERFEROMETER_PYTREES_REGISTERED = False + + class AnalysisInterferometer(AnalysisDataset): Result = ResultInterferometer Visualizer = VisualizerInterferometer @@ -173,11 +176,33 @@ def _register_fit_interferometer_pytrees() -> None: analysis — ride as aux so JAX does not recurse into them. Everything else (``galaxies`` and the autoarray wrappers it carries) is dynamic per fit. + + Idempotent — guarded by the module-level + ``_FIT_INTERFEROMETER_PYTREES_REGISTERED`` flag. See + ``autogalaxy/imaging/model/analysis.py`` for the cross-registration + rationale. """ - from autoarray.abstract_ndarray import register_instance_pytree + global _FIT_INTERFEROMETER_PYTREES_REGISTERED + if _FIT_INTERFEROMETER_PYTREES_REGISTERED: + return + + from autoarray.abstract_ndarray import ( + register_instance_pytree, + _pytree_registered_classes, + ) from autoarray.dataset.dataset_model import DatasetModel from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree + try: + from autofit.jax.pytrees import ( + _REGISTERED_INSTANCE_CLASSES as _af_registered, + ) + except ImportError: + _af_registered = set() + + if DatasetModel in _af_registered: + _pytree_registered_classes.add(DatasetModel) + register_instance_pytree( FitInterferometer, no_flatten=("dataset", "adapt_images", "settings"), @@ -185,6 +210,8 @@ def _register_fit_interferometer_pytrees() -> None: register_instance_pytree(DatasetModel) register_galaxies_pytree() + _FIT_INTERFEROMETER_PYTREES_REGISTERED = True + def save_attributes(self, paths: af.DirectoryPaths): """ Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder