Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
logger.setLevel(level="INFO")


_FIT_IMAGING_PYTREES_REGISTERED = False


class AnalysisImaging(AnalysisDataset):

Result = ResultImaging
Expand Down Expand Up @@ -174,11 +177,38 @@ def _register_fit_imaging_pytrees() -> None:
analysis — ride as aux so JAX does not recurse into them. Everything
else (``tracer``, ``dataset_model`` and the autoarray wrappers they
carry) is dynamic per fit.

Idempotent — guarded by the module-level
``_FIT_IMAGING_PYTREES_REGISTERED`` flag. ``DatasetModel`` and
``Tracer`` may already be registered by
``autofit.jax.pytrees.register_model`` (its
``_REGISTERED_INSTANCE_CLASSES`` set is independent of autoarray's
``_pytree_registered_classes``); cross-populate so
``register_instance_pytree`` short-circuits. Mirrors the defense in
``autogalaxy/ellipse/model/analysis.py``.
"""
from autoarray.abstract_ndarray import register_instance_pytree
global _FIT_IMAGING_PYTREES_REGISTERED
if _FIT_IMAGING_PYTREES_REGISTERED:
return

from autoarray.abstract_ndarray import (
register_instance_pytree,
_pytree_registered_classes,
)
from autoarray.dataset.dataset_model import DatasetModel
from autolens.lens.tracer import Tracer

try:
from autofit.jax.pytrees import (
_REGISTERED_INSTANCE_CLASSES as _af_registered,
)
except ImportError:
_af_registered = set()

for cls in (DatasetModel, Tracer):
if cls in _af_registered:
_pytree_registered_classes.add(cls)

register_instance_pytree(
FitImaging,
no_flatten=("dataset", "adapt_images", "settings"),
Expand All @@ -187,3 +217,5 @@ def _register_fit_imaging_pytrees() -> None:
# ``cosmology`` is a fixed physical constant per fit; ride as aux.
register_instance_pytree(Tracer, no_flatten=("cosmology",))

_FIT_IMAGING_PYTREES_REGISTERED = True

30 changes: 29 additions & 1 deletion autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
logger.setLevel(level="INFO")


_FIT_INTERFEROMETER_PYTREES_REGISTERED = False


class AnalysisInterferometer(AnalysisDataset):
Result = ResultInterferometer
Visualizer = VisualizerInterferometer
Expand Down Expand Up @@ -203,18 +206,43 @@ def _register_fit_interferometer_pytrees() -> None:
analysis — ride as aux so JAX does not recurse into them. Everything
else (``tracer`` and the autoarray wrappers it carries) is dynamic
per fit.

Idempotent — guarded by the module-level
``_FIT_INTERFEROMETER_PYTREES_REGISTERED`` flag. See
``autolens/imaging/model/analysis.py`` for the cross-registration
rationale.
"""
from autoarray.abstract_ndarray import register_instance_pytree
global _FIT_INTERFEROMETER_PYTREES_REGISTERED
if _FIT_INTERFEROMETER_PYTREES_REGISTERED:
return

from autoarray.abstract_ndarray import (
register_instance_pytree,
_pytree_registered_classes,
)
from autoarray.dataset.dataset_model import DatasetModel # fit-interferometer-pytree-mge
from autolens.lens.tracer import Tracer

try:
from autofit.jax.pytrees import (
_REGISTERED_INSTANCE_CLASSES as _af_registered,
)
except ImportError:
_af_registered = set()

for cls in (DatasetModel, Tracer):
if cls in _af_registered:
_pytree_registered_classes.add(cls)

register_instance_pytree(
FitInterferometer,
no_flatten=("dataset", "adapt_images", "settings"),
)
register_instance_pytree(Tracer, no_flatten=("cosmology",))
register_instance_pytree(DatasetModel) # fit-interferometer-pytree-mge

_FIT_INTERFEROMETER_PYTREES_REGISTERED = True

def save_attributes(self, paths: af.DirectoryPaths):
"""
Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder
Expand Down
Loading