diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 0f3e16dc9..ba9d0ea99 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1,5 +1,4 @@ import copy -import functools import inspect import json import logging @@ -1860,12 +1859,26 @@ def order_no(self) -> str: ] return ":".join(values) - @functools.cached_property + @property def parameterization(self) -> str: """ Describes the path to each of the PriorModels, its class - and its number of free parameters - """ + and its number of free parameters. + + Cached on first access in ``self.__dict__`` under the + ``_`` -prefixed key ``_parameterization_cache`` so that + ``Collection._instance_for_arguments`` and + ``ModelInstance.dict`` (which iterate ``__dict__`` and filter + underscore-prefixed keys) do not propagate the cached string + onto the constructed ``ModelInstance``. A plain + ``functools.cached_property`` writes to ``__dict__[name]`` + without a leading underscore, which would leak the string as + a non-array JAX pytree leaf and break ``jax.jit(fit_from)``. + """ + cached = self.__dict__.get("_parameterization_cache") + if cached is not None: + return cached + from .prior_model import Model formatter = TextFormatter(line_length=info_whitespace()) @@ -1900,6 +1913,7 @@ def parameterization(self) -> str: for group in find_groups(paths, limit=0): formatter.add(*group) + self.__dict__["_parameterization_cache"] = formatter.text return formatter.text @property diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 9b95d48f5..ce49b81a7 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -122,6 +122,12 @@ def __init__( self.use_jax_vmap = use_jax_vmap self.use_jax_jit = use_jax_jit + if getattr(self.analysis, "_use_jax", False): + from autofit.jax.pytrees import enable_pytrees, register_model + + enable_pytrees() + register_model(self.model) + self._call = self.call if self.use_jax_vmap: diff --git a/test_autofit/mapper/test_parameterization.py b/test_autofit/mapper/test_parameterization.py index ec3667a8b..416c07313 100644 --- a/test_autofit/mapper/test_parameterization.py +++ b/test_autofit/mapper/test_parameterization.py @@ -141,6 +141,41 @@ def test_tuple_instance_model_info(self, mapper): assert len(info.split("\n")) == len(mapper.info.split("\n")) +def test_parameterization_cache_does_not_leak_into_instance(): + """Regression: ``parameterization`` is cached in + ``self.__dict__["_parameterization_cache"]`` so that + ``Collection._instance_for_arguments`` and ``ModelInstance.dict`` + (which skip underscore-prefixed keys) do not propagate the cached + string onto the constructed instance. A plain + ``functools.cached_property`` would write to ``__dict__["parameterization"]`` + without an underscore, leaking the string into ``ModelInstance.dict`` + and downstream JAX pytree flattening — see commit 4564ae9a1.""" + + model = af.Collection(gaussian=af.Model(af.ex.Gaussian)) + + # Touch model.info → exercises the same propagation path that every + # workspace script hits at construction time. + _ = model.info + _ = model.parameterization # second access uses the cache + + # The cache must live behind an underscore key on the model. + assert "_parameterization_cache" in model.__dict__ + assert "parameterization" not in model.__dict__ + + instance = model.instance_from_prior_medians() + + # Neither the cached key nor the public name may appear on the + # constructed instance. + assert "parameterization" not in instance.__dict__ + assert "_parameterization_cache" not in instance.__dict__ + assert "parameterization" not in instance.dict + assert "_parameterization_cache" not in instance.dict + + # The instance must yield only model components when iterated. + for child in instance: + assert not isinstance(child, str) + + def test_integer_attributes(): model = af.Model(af.ex.Gaussian)