Skip to content

fix(jax): exclude cached_property descriptors from pytree flatten paths#343

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/cached-property-pytree-guard
May 29, 2026
Merged

fix(jax): exclude cached_property descriptors from pytree flatten paths#343
Jammy2211 merged 1 commit into
mainfrom
feature/cached-property-pytree-guard

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Follow-up to PyAutoFit #1300. Two __dict__-walking sites in abstract_ndarray.py register or flatten AbstractNDArray instances as JAX pytrees:

  1. _register_as_pytree(cls, ..., no_flatten=()) — the generic register_pytree_node path used for FitImaging / Tracer / DatasetModel.
  2. AbstractNDArray.instance_flatten(cls, instance) — the classmethod referenced by register_pytree_node_class callers.

Both use additive filters (no_flatten_set / cls.__no_flatten__). Today neither includes cached_property descriptor names, so any future @cached_property returning a non-array value on a Fit class would silently surface as a pytree leaf and break jax.jit at the boundary — same class of bug as PyAutoFit #1300 on the model side.

Union the existing exclusion sets with cached_property_names(cls) from PyAutoLabs/PyAutoConf#111 at both sites.

Test plan

  • pytest test_autoarray/test_abstract_ndarray_pytree_guard.py -v — 2/2 pass.
  • pytest test_autoarray -q — 840/840 pass (838 prior + 2 new). Numpy-only per feedback_no_jax_in_unit_tests.

Dependency

Depends on PyAutoLabs/PyAutoConf#111 (must merge first — provides the cached_property_names MRO walker).

🤖 Generated with Claude Code

Follow-up to PyAutoFit #1300. Two `__dict__`-walking sites in
abstract_ndarray.py register/flatten AbstractNDArray instances as JAX
pytrees:

1. `_register_as_pytree(cls, ..., no_flatten=())` — the generic
   `register_pytree_node` path used for FitImaging / Tracer /
   DatasetModel.
2. `AbstractNDArray.instance_flatten(cls, instance)` — the classmethod
   referenced by `register_pytree_node_class` callers.

Both use additive filters (`no_flatten_set` / `cls.__no_flatten__`)
that today only honour caller-supplied opt-outs. The pre-existing
filters don't include `cached_property` descriptor names, so any
future `@cached_property` returning a non-array value on a Fit class
would silently surface as a pytree leaf and break jax.jit at the
boundary (same class of bug as PyAutoFit #1300 on the Model side).

Union the existing exclusion sets with `cached_property_names(cls)`
from PyAutoConf #111 at both sites.

2 numpy-only regression tests cover:
- `instance_flatten` returns no string leaves when the class has a
  `@functools.cached_property` (after the cache is populated).
- The pre-existing `__no_flatten__` field exclusion is preserved
  alongside the new cached_property exclusion.
- Real array data still flows through the leaves (no collateral
  damage).

840/840 PyAutoArray tests pass (838 prior + 2 new).

Depends on: PyAutoLabs/PyAutoConf#111.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 1b6e4b2 into main May 29, 2026
2 of 6 checks passed
@Jammy2211 Jammy2211 deleted the feature/cached-property-pytree-guard branch May 29, 2026 07:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant