Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,14 @@ def register_instance_pytree(cls, no_flatten=()):
if cls in _pytree_registered_classes:
return
from autoconf.jax_wrapper import register_pytree_node
from autoconf.tools.decorators import cached_property_names

no_flatten_set = frozenset(no_flatten)
# Extend the caller-supplied no_flatten set with every
# ``cached_property``-style descriptor on ``cls`` so derived caches
# (e.g. heavy plotting/aggregator @cached_property values) never reach
# the pytree leaves and break ``jax.jit``. Mirrors the PyAutoFit-side
# defense; see PyAutoFit#1300 for the diagnosed class of bug.
no_flatten_set = frozenset(no_flatten) | cached_property_names(cls)

def flatten(instance):
dyn: list = []
Expand Down Expand Up @@ -174,12 +180,19 @@ def instance_flatten(cls, instance):
"""
Flatten an instance of an autoarray class into a tuple of its attributes (i.e. a pytree)
"""
from autoconf.tools.decorators import cached_property_names

# Union the class-level ``__no_flatten__`` opt-out with any
# ``cached_property`` descriptor names so derived caches don't
# surface as pytree leaves. Mirrors the defense in
# ``_register_as_pytree`` (PyAutoFit#1300 follow-up).
excluded = set(cls.__no_flatten__) | cached_property_names(cls)
keys, values = zip(
*sorted(
{
key: value
for key, value in instance.__dict__.items()
if key not in cls.__no_flatten__
if key not in excluded
}.items()
)
)
Expand Down
82 changes: 82 additions & 0 deletions test_autoarray/test_abstract_ndarray_pytree_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Regression: cached_property descriptors on AbstractNDArray subclasses
must be filtered from ``instance_flatten`` so derived caches never reach
the JAX pytree leaves.

NumPy-only per the project rule [[feedback_no_jax_in_unit_tests]]:
exercise the ``instance_flatten`` classmethod directly (which is what
the JAX pytree path delegates to) and assert composition is correct.
"""

import functools

import numpy as np

from autoarray.abstract_ndarray import AbstractNDArray


class _FakeArray(AbstractNDArray):
    """Bare-bones ``AbstractNDArray`` subclass used as a flatten fixture.

    It carries a single ``functools.cached_property`` (``heavy_summary``)
    whose value is a plain string, so the tests in this module can check
    that ``instance_flatten`` leaves that cached value out of its result.
    """

    # Attribute names the class opts out of flattening.
    __no_flatten__ = ("use_jax",)

    def __init__(self, array):
        # The parent initialiser is intentionally not called (the stated
        # aim is to stay clear of its JAX-registration path); the flatten
        # tests only need these three attributes in the instance __dict__.
        self._array = np.asarray(array)
        self._is_transformed = False
        self.use_jax = False

    @property
    def native(self):
        # Concrete stand-in for the ``native`` property the base class is
        # understood to require; its value plays no part in flattening,
        # so it simply hands back the stored array.
        return self._array

    @functools.cached_property
    def heavy_summary(self):
        # First access stores this string in the instance __dict__ under
        # the key "heavy_summary".
        return "a-pretty-printed-summary-of-the-array"


def test_instance_flatten_excludes_cached_property_names():
    """Cached-property values must not appear in ``instance_flatten`` output.

    After ``heavy_summary`` has been read (and therefore stored in the
    instance ``__dict__``), flattening a ``_FakeArray`` must return keys
    that contain neither the ``__no_flatten__`` entry ``"use_jax"`` nor
    ``"heavy_summary"``, and none of the returned leaves may be a string.
    """
    summary_text = "a-pretty-printed-summary-of-the-array"
    fake = _FakeArray([1.0, 2.0, 3.0])

    # Reading the cached property stores its value in the instance __dict__.
    _ = fake.heavy_summary
    assert fake.__dict__["heavy_summary"] == summary_text

    leaves, keys = _FakeArray.instance_flatten(fake)

    # The class-level __no_flatten__ opt-out is still honoured.
    assert "use_jax" not in keys
    # The cached-property name is filtered out as well.
    assert "heavy_summary" not in keys
    # No string value made it into the leaves.
    string_leaves = [leaf for leaf in leaves if isinstance(leaf, str)]
    assert not string_leaves


def test_instance_flatten_preserves_array_data():
    """Filtering cached-property names must not drop the real array data.

    With ``heavy_summary`` already cached on the instance, the flattened
    keys must still include ``"_array"`` and the leaf at the matching
    position must equal the values the fixture was built from.
    """
    values = [1.0, 2.0, 3.0]
    fake = _FakeArray(values)
    _ = fake.heavy_summary  # populate the cache before flattening

    leaves, keys = _FakeArray.instance_flatten(fake)

    assert "_array" in keys
    np.testing.assert_array_equal(
        leaves[keys.index("_array")],
        np.asarray(values),
    )
Loading