fix(jax): exclude cached_property descriptors from pytree flatten paths#343
Merged
Conversation
Follow-up to PyAutoFit #1300. Two `__dict__`-walking sites in abstract_ndarray.py register/flatten AbstractNDArray instances as JAX pytrees: 1. `_register_as_pytree(cls, ..., no_flatten=())` — the generic `register_pytree_node` path used for FitImaging / Tracer / DatasetModel. 2. `AbstractNDArray.instance_flatten(cls, instance)` — the classmethod referenced by `register_pytree_node_class` callers. Both use additive filters (`no_flatten_set` / `cls.__no_flatten__`) that today only honour caller-supplied opt-outs. The pre-existing filters don't include `cached_property` descriptor names, so any future `@cached_property` returning a non-array value on a Fit class would silently surface as a pytree leaf and break jax.jit at the boundary (same class of bug as PyAutoFit #1300 on the Model side). Union the existing exclusion sets with `cached_property_names(cls)` from PyAutoConf #111 at both sites. 2 numpy-only regression tests cover: - `instance_flatten` returns no string leaves when the class has a `@functools.cached_property` (after the cache is populated). - The pre-existing `__no_flatten__` field exclusion is preserved alongside the new cached_property exclusion. - Real array data still flows through the leaves (no collateral damage). 840/840 PyAutoArray tests pass (838 prior + 2 new). Depends on: PyAutoLabs/PyAutoConf#111. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up to PyAutoFit #1300. Two
__dict__-walking sites inabstract_ndarray.pyregister or flattenAbstractNDArrayinstances as JAX pytrees:_register_as_pytree(cls, ..., no_flatten=())— the genericregister_pytree_nodepath used for FitImaging / Tracer / DatasetModel.AbstractNDArray.instance_flatten(cls, instance)— the classmethod referenced byregister_pytree_node_classcallers.Both use additive filters (
no_flatten_set/cls.__no_flatten__). Today neither includescached_propertydescriptor names, so any future@cached_propertyreturning a non-array value on a Fit class would silently surface as a pytree leaf and breakjax.jitat the boundary — same class of bug as PyAutoFit #1300 on the model side.Union the existing exclusion sets with
cached_property_names(cls)from PyAutoLabs/PyAutoConf#111 at both sites.Test plan
pytest test_autoarray/test_abstract_ndarray_pytree_guard.py -v— 2/2 pass.pytest test_autoarray -q— 840/840 pass (838 prior + 2 new). Numpy-only perfeedback_no_jax_in_unit_tests.Dependency
Depends on PyAutoLabs/PyAutoConf#111 (must merge first — provides the
cached_property_namesMRO walker).🤖 Generated with Claude Code