Skip to content

Commit 697e3b5

Browse files
committed
Add assertion to MultiWriterTokens
1 parent bf09110 commit 697e3b5

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

synapse/types/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
458458
represented by a default `stream` attribute and a map of instance name to
459459
stream position of any writers that are ahead of the default stream
460460
position.
461+
462+
The values in `instance_map` must be greater than the `stream` attribute.
461463
"""
462464

463465
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -472,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
472474
kw_only=True,
473475
)
474476

477+
def __attrs_post_init__(self):
478+
# Enforce that all instances have a value greater than the min stream
479+
# position.
480+
for v in self.instance_map.values():
481+
if v < self.stream:
482+
raise ValueError(
483+
"'instance_map' includes a stream position before the main 'stream' attribute"
484+
)
485+
475486
@classmethod
476487
@abc.abstractmethod
477488
async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -641,6 +652,8 @@ def __attrs_post_init__(self) -> None:
641652
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
642653
)
643654

655+
super().__attrs_post_init__()
656+
644657
@classmethod
645658
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
646659
try:

tests/test_types.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,9 @@ def test_instance_map(self) -> None:
161161
parsed_token = self.get_success(RoomStreamToken.parse(store, string_token))
162162
self.assertEqual(parsed_token, token)
163163

164-
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
165-
def test_instance_map_behind(self) -> None:
166-
"""Test for stream token with instance map, where instance map entries
167-
are from before stream token."""
168-
store = self.hs.get_datastores().main
164+
def test_instance_map_assertion(self) -> None:
165+
"""Test that we assert values in the instance map are greater than the
166+
min stream position"""
169167

170-
token = RoomStreamToken(stream=5, instance_map=immutabledict({"foo": 4}))
171-
172-
string_token = self.get_success(token.to_string(store))
173-
self.assertEqual(string_token, "s5")
174-
175-
parsed_token = self.get_success(RoomStreamToken.parse(store, string_token))
176-
self.assertEqual(parsed_token, RoomStreamToken(stream=5))
168+
with self.assertRaises(ValueError):
169+
RoomStreamToken(stream=5, instance_map=immutabledict({"foo": 4}))

0 commit comments

Comments
 (0)