From 48bcb809027ab89021132fea5b920fe51f296d64 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 28 May 2026 19:38:22 +0100 Subject: [PATCH] fix(jax): keep parameterization cache off ModelInstance + auto-register pytrees MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two coupled fixes restoring the JAX `jit(fit_from)` path that broke when commit 4564ae9a1 made `AbstractPriorModel.parameterization` a `functools.cached_property`. `cached_property` writes to `self.__dict__["parameterization"]`. After any `model.info` access, `Collection._instance_for_arguments` (which iterates `__dict__` and skips only underscore-prefixed keys) propagates the cached string onto every `ModelInstance`. The string then surfaces as a non-array JAX pytree leaf (autogalaxy_workspace_test + autolens_workspace_test `jax_likelihood_functions/*` — 38 scripts) and makes `for x in instance:` yield strings instead of profiles (autofit_workspace `overview/overview_1_the_basics.py`). Fix 1: store the cache under the underscore-prefixed key `_parameterization_cache` so both `Collection._instance_for_arguments` and `ModelInstance.dict` filter it out. Preserves the 2.7s → 0.05s perf win from 4564ae9a1. Fix 2: auto-call `enable_pytrees() + register_model(self.model)` from `Fitness.__init__` whenever `analysis._use_jax=True`. Both helpers are idempotent, so workspaces that still call them explicitly keep working. New JAX-enabled workspaces don't need the boilerplate. Verified locally: - 1413/1413 PyAutoFit unit tests pass + new `test_parameterization_cache_does_not_leak_into_instance` regression - `autofit_workspace/scripts/overview/overview_1_the_basics.py` runs to completion (cluster C4 reproducer) - `autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py` prints "PASS: jit(fit_from) round-trip matches NumPy scalar" (cluster C1 reproducer) Follow-up: a structural defense across the four `__dict__`-iterators in `autofit/mapper/` plus `autoarray/abstract_ndarray.py` will ship as a separate PR — a `_cached_property_names(cls)` classmethod applied as an extra filter at every leak site so the next future `@cached_property` on a model class cannot reintroduce this bug. Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/mapper/prior_model/abstract.py | 22 +++++++++--- autofit/non_linear/fitness.py | 6 ++++ test_autofit/mapper/test_parameterization.py | 35 ++++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 0f3e16dc9..ba9d0ea99 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1,5 +1,4 @@ import copy -import functools import inspect import json import logging @@ -1860,12 +1859,26 @@ def order_no(self) -> str: ] return ":".join(values) - @functools.cached_property + @property def parameterization(self) -> str: """ Describes the path to each of the PriorModels, its class - and its number of free parameters - """ + and its number of free parameters. + + Cached on first access in ``self.__dict__`` under the + ``_`` -prefixed key ``_parameterization_cache`` so that + ``Collection._instance_for_arguments`` and + ``ModelInstance.dict`` (which iterate ``__dict__`` and filter + underscore-prefixed keys) do not propagate the cached string + onto the constructed ``ModelInstance``. A plain + ``functools.cached_property`` writes to ``__dict__[name]`` + without a leading underscore, which would leak the string as + a non-array JAX pytree leaf and break ``jax.jit(fit_from)``. + """ + cached = self.__dict__.get("_parameterization_cache") + if cached is not None: + return cached + from .prior_model import Model formatter = TextFormatter(line_length=info_whitespace()) @@ -1900,6 +1913,7 @@ def parameterization(self) -> str: for group in find_groups(paths, limit=0): formatter.add(*group) + self.__dict__["_parameterization_cache"] = formatter.text return formatter.text @property diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 9b95d48f5..ce49b81a7 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -122,6 +122,12 @@ def __init__( self.use_jax_vmap = use_jax_vmap self.use_jax_jit = use_jax_jit + if getattr(self.analysis, "_use_jax", False): + from autofit.jax.pytrees import enable_pytrees, register_model + + enable_pytrees() + register_model(self.model) + self._call = self.call if self.use_jax_vmap: diff --git a/test_autofit/mapper/test_parameterization.py b/test_autofit/mapper/test_parameterization.py index ec3667a8b..416c07313 100644 --- a/test_autofit/mapper/test_parameterization.py +++ b/test_autofit/mapper/test_parameterization.py @@ -141,6 +141,41 @@ def test_tuple_instance_model_info(self, mapper): assert len(info.split("\n")) == len(mapper.info.split("\n")) +def test_parameterization_cache_does_not_leak_into_instance(): + """Regression: ``parameterization`` is cached in + ``self.__dict__["_parameterization_cache"]`` so that + ``Collection._instance_for_arguments`` and ``ModelInstance.dict`` + (which skip underscore-prefixed keys) do not propagate the cached + string onto the constructed instance. A plain + ``functools.cached_property`` would write to ``__dict__["parameterization"]`` + without an underscore, leaking the string into ``ModelInstance.dict`` + and downstream JAX pytree flattening — see commit 4564ae9a1.""" + + model = af.Collection(gaussian=af.Model(af.ex.Gaussian)) + + # Touch model.info → exercises the same propagation path that every + # workspace script hits at construction time. + _ = model.info + _ = model.parameterization # second access uses the cache + + # The cache must live behind an underscore key on the model. + assert "_parameterization_cache" in model.__dict__ + assert "parameterization" not in model.__dict__ + + instance = model.instance_from_prior_medians() + + # Neither the cached key nor the public name may appear on the + # constructed instance. + assert "parameterization" not in instance.__dict__ + assert "_parameterization_cache" not in instance.__dict__ + assert "parameterization" not in instance.dict + assert "_parameterization_cache" not in instance.dict + + # The instance must yield only model components when iterated. + for child in instance: + assert not isinstance(child, str) + + def test_integer_attributes(): model = af.Model(af.ex.Gaussian)