Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions autogalaxy/profiles/light/linear/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def __eq__(self, other):
and self.pytree_token == other.pytree_token
)

def __getstate__(self):
# pytree_token is a process-local counter increment used only as a
# stable hash/eq identity for jax.jit flatten/unflatten round-trips.
# Persisting it would copy a number from one process's counter into
# another's — the JAX path uses register_model's attr_const which
# reads vars(self) directly and is unaffected by __getstate__.
return {k: v for k, v in self.__dict__.items() if k != "pytree_token"}

def __setstate__(self, state):
self.__dict__.update(state)
if "pytree_token" not in state:
self.pytree_token = next(LightProfileLinear._pytree_token_counter)

@property
def regularization(self):
return None
Expand Down
53 changes: 53 additions & 0 deletions test_autogalaxy/profiles/light/linear/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,56 @@ def test__lp_instance_from__returns_instance_with_correct_intensity():
)

assert lp_non_linear.intensity == 3.0


def test__pytree_token_is_int_and_unique():
lp_0 = ag.lp_linear.Sersic()
lp_1 = ag.lp_linear.Sersic()

assert isinstance(lp_0.pytree_token, int)
assert isinstance(lp_1.pytree_token, int)
assert lp_0.pytree_token != lp_1.pytree_token

assert isinstance(hash(lp_0), int)
assert hash(lp_0) == hash(lp_0)
assert hash(lp_0) != hash(lp_1)


def test__getstate__omits_pytree_token():
lp = ag.lp_linear.Sersic()
state = lp.__getstate__()

assert "pytree_token" not in state
assert "effective_radius" in state


def test__setstate__assigns_fresh_pytree_token_when_missing():
lp = ag.lp_linear.Sersic()
state = lp.__getstate__()

restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic)
restored.__setstate__(state)

assert isinstance(restored.pytree_token, int)
assert isinstance(hash(restored), int)


def test__pickle_roundtrip_preserves_int_hash():
import pickle

lp = ag.lp_linear.Sersic()
restored = pickle.loads(pickle.dumps(lp))

assert isinstance(hash(restored), int)
assert isinstance(restored.pytree_token, int)
assert restored.effective_radius == lp.effective_radius


def test__setstate__preserves_pytree_token_when_present():
lp = ag.lp_linear.Sersic()
state_with_token = dict(lp.__dict__)

restored = ag.lp_linear.Sersic.__new__(ag.lp_linear.Sersic)
restored.__setstate__(state_with_token)

assert restored.pytree_token == lp.pytree_token
Loading