diff --git a/autogalaxy/profiles/light/linear/abstract.py b/autogalaxy/profiles/light/linear/abstract.py index 2af05f7e..56656cf1 100644 --- a/autogalaxy/profiles/light/linear/abstract.py +++ b/autogalaxy/profiles/light/linear/abstract.py @@ -62,6 +62,19 @@ def __eq__(self, other): and self.pytree_token == other.pytree_token ) + def __getstate__(self): + # pytree_token is a process-local counter increment used only as a + # stable hash/eq identity for jax.jit flatten/unflatten round-trips. + # Persisting it would copy a number from one process's counter into + # another's — the JAX path uses register_model's attr_const which + # reads vars(self) directly and is unaffected by __getstate__. + return {k: v for k, v in self.__dict__.items() if k != "pytree_token"} + + def __setstate__(self, state): + self.__dict__.update(state) + if "pytree_token" not in state: + self.pytree_token = next(LightProfileLinear._pytree_token_counter) + @property def regularization(self): return None diff --git a/test_autogalaxy/profiles/light/linear/test_abstract.py b/test_autogalaxy/profiles/light/linear/test_abstract.py index bbcdff0d..6fe4c83e 100644 --- a/test_autogalaxy/profiles/light/linear/test_abstract.py +++ b/test_autogalaxy/profiles/light/linear/test_abstract.py @@ -98,3 +98,56 @@ def test__lp_instance_from__returns_instance_with_correct_intensity(): ) assert lp_non_linear.intensity == 3.0 + + +def test__pytree_token_is_int_and_unique(): + lp_0 = ag.lp_linear.Sersic() + lp_1 = ag.lp_linear.Sersic() + + assert isinstance(lp_0.pytree_token, int) + assert isinstance(lp_1.pytree_token, int) + assert lp_0.pytree_token != lp_1.pytree_token + + assert isinstance(hash(lp_0), int) + assert hash(lp_0) == hash(lp_0) + assert hash(lp_0) != hash(lp_1) + + +def test__getstate__omits_pytree_token(): + lp = ag.lp_linear.Sersic() + state = lp.__getstate__() + + assert "pytree_token" not in state + assert "effective_radius" in state + + +def test__setstate__assigns_fresh_pytree_token_when_missing(): + lp = ag.lp_linear.Sersic() + state = lp.__getstate__() + + restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic) + restored.__setstate__(state) + + assert isinstance(restored.pytree_token, int) + assert isinstance(hash(restored), int) + + +def test__pickle_roundtrip_preserves_int_hash(): + import pickle + + lp = ag.lp_linear.Sersic() + restored = pickle.loads(pickle.dumps(lp)) + + assert isinstance(hash(restored), int) + assert isinstance(restored.pytree_token, int) + assert restored.effective_radius == lp.effective_radius + + +def test__setstate__preserves_pytree_token_when_present(): + lp = ag.lp_linear.Sersic() + state_with_token = dict(lp.__dict__) + + restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic) + restored.__setstate__(state_with_token) + + assert restored.pytree_token == lp.pytree_token