diff --git a/autoarray/abstract_ndarray.py b/autoarray/abstract_ndarray.py
index f25a7d8e..5c64c490 100644
--- a/autoarray/abstract_ndarray.py
+++ b/autoarray/abstract_ndarray.py
@@ -102,8 +102,14 @@ def register_instance_pytree(cls, no_flatten=()):
     if cls in _pytree_registered_classes:
         return
     from autoconf.jax_wrapper import register_pytree_node
+    from autoconf.tools.decorators import cached_property_names
 
-    no_flatten_set = frozenset(no_flatten)
+    # Extend the caller-supplied no_flatten set with every
+    # ``cached_property``-style descriptor on ``cls`` so derived caches
+    # (e.g. heavy plotting/aggregator @cached_property values) never reach
+    # the pytree leaves and break ``jax.jit``. Mirrors the PyAutoFit-side
+    # defense; see PyAutoFit#1300 for the diagnosed class of bug.
+    no_flatten_set = frozenset(no_flatten) | cached_property_names(cls)
 
     def flatten(instance):
         dyn: list = []
@@ -174,12 +180,19 @@ def instance_flatten(cls, instance):
         """
         Flatten an instance of an autoarray class into a tuple of its attributes (i.e.. a pytree)
         """
+        from autoconf.tools.decorators import cached_property_names
+
+        # Union the class-level ``__no_flatten__`` opt-out with any
+        # ``cached_property`` descriptor names so derived caches don't
+        # surface as pytree leaves. Mirrors the defense in
+        # ``_register_as_pytree`` (PyAutoFit#1300 follow-up).
+        excluded = set(cls.__no_flatten__) | cached_property_names(cls)
         keys, values = zip(
             *sorted(
                 {
                     key: value
                     for key, value in instance.__dict__.items()
-                    if key not in cls.__no_flatten__
+                    if key not in excluded
                 }.items()
             )
         )
diff --git a/test_autoarray/test_abstract_ndarray_pytree_guard.py b/test_autoarray/test_abstract_ndarray_pytree_guard.py
new file mode 100644
index 00000000..ddea04a1
--- /dev/null
+++ b/test_autoarray/test_abstract_ndarray_pytree_guard.py
@@ -0,0 +1,82 @@
+"""Regression: cached_property descriptors on AbstractNDArray subclasses
+must be filtered from ``instance_flatten`` so derived caches never reach
+the JAX pytree leaves.
+
+NumPy-only per the project rule [[feedback_no_jax_in_unit_tests]]:
+exercise the ``instance_flatten`` classmethod directly (which is what
+the JAX pytree path delegates to) and assert composition is correct.
+"""
+
+import functools
+
+import numpy as np
+
+from autoarray.abstract_ndarray import AbstractNDArray
+
+
+class _FakeArray(AbstractNDArray):
+    """Minimal AbstractNDArray subclass that adds a ``@cached_property``
+    returning a string. Used to assert the guard filters it from
+    ``instance_flatten``."""
+
+    __no_flatten__ = ("use_jax",)
+
+    def __init__(self, array):
+        # Skip AbstractNDArray.__init__ to avoid the JAX-registration path
+        # — we only need the dict-shape for the flatten test.
+        self._array = np.asarray(array)
+        self._is_transformed = False
+        self.use_jax = False
+
+    @property
+    def native(self):
+        # AbstractNDArray declares ``native`` abstract; the body is
+        # irrelevant to the flatten path so just echo ``_array``.
+        return self._array
+
+    @functools.cached_property
+    def heavy_summary(self):
+        return "a-pretty-printed-summary-of-the-array"
+
+
+def test_instance_flatten_excludes_cached_property_names():
+    """``AbstractNDArray.instance_flatten`` unions the class-level
+    ``__no_flatten__`` with the result of
+    ``autoconf.tools.decorators.cached_property_names`` so derived
+    cached strings stay out of the pytree leaves.
+
+    This pins the structural defense that follows PyAutoFit#1300: the
+    leak surfaces today only on the Model side, but the same opt-out
+    filter shape on AbstractNDArray descendants would break ``jax.jit``
+    the moment anyone added a ``@cached_property`` returning a
+    non-array value to a Fit class."""
+
+    arr = _FakeArray([1.0, 2.0, 3.0])
+
+    # Trigger the cached property: it writes "...summary..." into __dict__.
+    _ = arr.heavy_summary
+    assert arr.__dict__["heavy_summary"] == "a-pretty-printed-summary-of-the-array"
+
+    leaves, keys = _FakeArray.instance_flatten(arr)
+
+    # The pre-existing __no_flatten__ exclusion ("use_jax") still applies.
+    assert "use_jax" not in keys
+    # The new cached_property exclusion fires too.
+    assert "heavy_summary" not in keys
+    # No string leaves anywhere.
+    assert not any(isinstance(leaf, str) for leaf in leaves)
+
+
+def test_instance_flatten_preserves_array_data():
+    """Sanity check: filtering cached_property names does not collateral-
+    damage real array data. The underlying numpy array must still appear
+    in the leaves."""
+
+    arr = _FakeArray([1.0, 2.0, 3.0])
+    _ = arr.heavy_summary  # poison the cache before flattening
+
+    leaves, keys = _FakeArray.instance_flatten(arr)
+
+    assert "_array" in keys
+    array_index = keys.index("_array")
+    np.testing.assert_array_equal(leaves[array_index], np.asarray([1.0, 2.0, 3.0]))