From 70288f55786b2837f1cd8211a096863ec12fa965 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 28 Apr 2026 13:36:33 +0100 Subject: [PATCH] fix(light/linear): mark pytree_token as ephemeral via __getstate__/__setstate__ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pytree_token is a process-local counter increment used as a stable hash/eq identity for LightProfileLinear instances across jax.jit flatten/unflatten. It is not meaningful state to persist — copying a number from one process's counter into another's does not preserve identity across processes. PyAutoFit's database serializer routes every numeric attribute through the Value SQL row (sa.Float column), so persisting pytree_token as an int and reading it back yields a float. Python >=3.12 strictly requires __hash__ to return int, so any dict keyed on a LightProfileLinear (e.g. the linear_light_profile_intensity_dict in autogalaxy/abstract_fit.py) raised TypeError on the visualization path after a fit was loaded via FitImagingAgg. Add __getstate__ that omits pytree_token and __setstate__ that assigns a fresh value if missing. PyAutoFit's Instance._from_object and Object.__call__ already check for these methods and honour them, so no PyAutoFit changes are required. The JAX-jit path uses register_model's attr_const which reads vars(self) directly and is unaffected. Fixes the smoke-test failure on autolens_workspace_test/main: scripts/database/scrape/general.py. Co-Authored-By: Claude Opus 4.7 (1M context) --- autogalaxy/profiles/light/linear/abstract.py | 13 +++++ .../profiles/light/linear/test_abstract.py | 53 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/autogalaxy/profiles/light/linear/abstract.py b/autogalaxy/profiles/light/linear/abstract.py index 2af05f7e..56656cf1 100644 --- a/autogalaxy/profiles/light/linear/abstract.py +++ b/autogalaxy/profiles/light/linear/abstract.py @@ -62,6 +62,19 @@ def __eq__(self, other): and self.pytree_token == other.pytree_token ) + def __getstate__(self): + # pytree_token is a process-local counter increment used only as a + # stable hash/eq identity for jax.jit flatten/unflatten round-trips. + # Persisting it would copy a number from one process's counter into + # another's — the JAX path uses register_model's attr_const which + # reads vars(self) directly and is unaffected by __getstate__. + return {k: v for k, v in self.__dict__.items() if k != "pytree_token"} + + def __setstate__(self, state): + self.__dict__.update(state) + if "pytree_token" not in state: + self.pytree_token = next(LightProfileLinear._pytree_token_counter) + @property def regularization(self): return None diff --git a/test_autogalaxy/profiles/light/linear/test_abstract.py b/test_autogalaxy/profiles/light/linear/test_abstract.py index bbcdff0d..6fe4c83e 100644 --- a/test_autogalaxy/profiles/light/linear/test_abstract.py +++ b/test_autogalaxy/profiles/light/linear/test_abstract.py @@ -98,3 +98,56 @@ def test__lp_instance_from__returns_instance_with_correct_intensity(): ) assert lp_non_linear.intensity == 3.0 + + +def test__pytree_token_is_int_and_unique(): + lp_0 = ag.lp_linear.Sersic() + lp_1 = ag.lp_linear.Sersic() + + assert isinstance(lp_0.pytree_token, int) + assert isinstance(lp_1.pytree_token, int) + assert lp_0.pytree_token != lp_1.pytree_token + + assert isinstance(hash(lp_0), int) + assert hash(lp_0) == hash(lp_0) + assert hash(lp_0) != hash(lp_1) + + +def test__getstate__omits_pytree_token(): + lp = ag.lp_linear.Sersic() + state = lp.__getstate__() + + assert "pytree_token" not in state + assert "effective_radius" in state + + +def test__setstate__assigns_fresh_pytree_token_when_missing(): + lp = ag.lp_linear.Sersic() + state = lp.__getstate__() + + restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic) + restored.__setstate__(state) + + assert isinstance(restored.pytree_token, int) + assert isinstance(hash(restored), int) + + +def test__pickle_roundtrip_preserves_int_hash(): + import pickle + + lp = ag.lp_linear.Sersic() + restored = pickle.loads(pickle.dumps(lp)) + + assert isinstance(hash(restored), int) + assert isinstance(restored.pytree_token, int) + assert restored.effective_radius == lp.effective_radius + + +def test__setstate__preserves_pytree_token_when_present(): + lp = ag.lp_linear.Sersic() + state_with_token = dict(lp.__dict__) + + restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic) + restored.__setstate__(state_with_token) + + assert restored.pytree_token == lp.pytree_token