Skip to content

Commit 8201e58

Browse files
authored
Update and stabilize mutual rooms support (MSC2666) (#19511)
Updates the error codes to match MSC2666 changes (user ID query param validation + proper errcode for requesting rooms with self), added the new `count` field, and stabilized the endpoint.
1 parent 3d960d8 commit 8201e58

File tree

6 files changed

+110
-43
lines changed

6 files changed

+110
-43
lines changed

changelog.d/19511.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update and stabilize support for [MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666): Get rooms in common with another user. Contributed by @tulir @ Beeper.

synapse/config/experimental.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,6 @@ def read_config(
422422
# previously calculated push actions.
423423
self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
424424

425-
# MSC2666: Query mutual rooms between two users.
426-
self.msc2666_enabled: bool = experimental.get("msc2666_enabled", False)
427-
428425
# MSC2815 (allow room moderators to view redacted event content)
429426
self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False)
430427

synapse/rest/client/mutual_rooms.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from synapse.http.server import HttpServer
3030
from synapse.http.servlet import RestServlet, parse_strings_from_args
3131
from synapse.http.site import SynapseRequest
32-
from synapse.types import JsonDict
32+
from synapse.types import JsonDict, UserID
3333

3434
from ._base import client_patterns
3535

@@ -65,13 +65,10 @@ def _parse_mutual_rooms_batch_token_args(args: dict[bytes, list[bytes]]) -> str
6565

6666
class UserMutualRoomsServlet(RestServlet):
6767
"""
68-
GET /uk.half-shot.msc2666/user/mutual_rooms?user_id={user_id}&from={token} HTTP/1.1
68+
GET /mutual_rooms?user_id={user_id}&from={token} HTTP/1.1
6969
"""
7070

71-
PATTERNS = client_patterns(
72-
"/uk.half-shot.msc2666/user/mutual_rooms$",
73-
releases=(), # This is an unstable feature
74-
)
71+
PATTERNS = [*client_patterns("/mutual_rooms$", releases=("v1",))]
7572

7673
def __init__(self, hs: "HomeServer"):
7774
super().__init__()
@@ -82,7 +79,9 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
8279
# twisted.web.server.Request.args is incorrectly defined as Any | None
8380
args: dict[bytes, list[bytes]] = request.args # type: ignore
8481

85-
user_ids = parse_strings_from_args(args, "user_id", required=True)
82+
user_ids = parse_strings_from_args(
83+
args, "user_id", required=True, encoding="utf-8"
84+
)
8685
from_batch = _parse_mutual_rooms_batch_token_args(args)
8786

8887
if len(user_ids) > 1:
@@ -93,13 +92,19 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
9392
)
9493

9594
user_id = user_ids[0]
95+
if not UserID.is_valid_strict(user_id):
96+
raise SynapseError(
97+
HTTPStatus.BAD_REQUEST,
98+
"Invalid user_id query parameter",
99+
errcode=Codes.INVALID_PARAM,
100+
)
96101

97102
requester = await self.auth.get_user_by_req(request)
98103
if user_id == requester.user.to_string():
99104
raise SynapseError(
100105
HTTPStatus.BAD_REQUEST,
101106
"You cannot request a list of shared rooms with yourself",
102-
errcode=Codes.UNKNOWN,
107+
errcode=Codes.INVALID_PARAM,
103108
)
104109

105110
# Sort here instead of the database function, so that we don't expose
@@ -109,6 +114,7 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
109114
frozenset((requester.user.to_string(), user_id))
110115
)
111116
)
117+
total_count = len(rooms)
112118

113119
if from_batch:
114120
# A from_batch token was provided, so cut off any rooms where the ID is
@@ -123,7 +129,7 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
123129

124130
if len(rooms) <= MUTUAL_ROOMS_BATCH_LIMIT:
125131
# We've reached the end of the list, don't return a batch token
126-
return 200, {"joined": rooms}
132+
return 200, {"joined": rooms, "count": total_count}
127133

128134
rooms = rooms[:MUTUAL_ROOMS_BATCH_LIMIT]
129135
# We use urlsafe unpadded base64 encoding for the batch token in order to
@@ -135,11 +141,14 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
135141
# in the room ID. In the event that some silly user does that, don't let
136142
# them paginate further.
137143
if next_batch == from_batch:
138-
return 200, {"joined": rooms}
144+
return 200, {"joined": rooms, "count": total_count}
139145

140-
return 200, {"joined": list(rooms), "next_batch": next_batch}
146+
return 200, {
147+
"joined": rooms,
148+
"next_batch": next_batch,
149+
"count": total_count,
150+
}
141151

142152

143153
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
144-
if hs.config.experimental.msc2666_enabled:
145-
UserMutualRoomsServlet(hs).register(http_server)
154+
UserMutualRoomsServlet(hs).register(http_server)

synapse/rest/client/versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
144144
# Implements additional endpoints as described in MSC2432
145145
"org.matrix.msc2432": True,
146146
# Implements additional endpoints as described in MSC2666
147-
"uk.half-shot.msc2666.query_mutual_rooms": self.config.experimental.msc2666_enabled,
147+
"uk.half-shot.msc2666.query_mutual_rooms.stable": True,
148148
# Whether new rooms will be set to encrypted or not (based on presets).
149149
"io.element.e2ee_forced.public": self.e2ee_forced_public,
150150
"io.element.e2ee_forced.private": self.e2ee_forced_private,

synapse/types/__init__.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def is_valid(cls: type[DS], s: str) -> bool:
343343
# possible for invalid data to exist in room-state, etc.
344344
parse_and_validate_server_name(obj.domain)
345345
return True
346-
except Exception:
346+
except (SynapseError, ValueError):
347347
return False
348348

349349
__repr__ = to_string
@@ -355,6 +355,29 @@ class UserID(DomainSpecificString):
355355

356356
SIGIL = "@"
357357

358+
@classmethod
359+
def is_valid_strict(cls, s: str) -> bool:
360+
"""
361+
Parses the input string and attempts to ensure it is a valid and compliant user
362+
ID according to https://spec.matrix.org/v1.17/appendices/#historical-user-ids.
363+
364+
This should be used with care: there are existing non-compliant user IDs in the
365+
wild with empty or non-ASCII localparts, which will be rejected by this method.
366+
"""
367+
if len(s.encode("utf-8")) > 255:
368+
return False
369+
try:
370+
obj = cls.from_string(s)
371+
if not is_compliant_user_id_localpart(obj.localpart):
372+
return False
373+
# Apply additional validation to the domain. This is only done
374+
# during is_valid (and not part of from_string) since it is
375+
# possible for invalid data to exist in room-state, etc.
376+
parse_and_validate_server_name(obj.domain)
377+
return True
378+
except (SynapseError, ValueError):
379+
return False
380+
358381

359382
@attr.s(slots=True, frozen=True, repr=False)
360383
class RoomAlias(DomainSpecificString):
@@ -453,19 +476,46 @@ class EventID(DomainSpecificString):
453476

454477

455478
def contains_invalid_mxid_characters(localpart: str) -> bool:
456-
"""Check for characters not allowed in an mxid or groupid localpart
479+
"""
480+
Check for characters not allowed in a modern user ID localpart.
481+
482+
This is primarily used for new registrations and MUST NOT be used to validate
483+
existing user IDs, as there are real users whose user IDs don't follow this
484+
character set.
485+
486+
See https://spec.matrix.org/v1.17/appendices/#user-identifiers
457487
458488
Args:
459489
localpart: the localpart to be checked
460-
use_extended_character_set: True to use the extended allowed characters
461-
from MSC4009.
462490
463491
Returns:
464492
True if there are any naughty characters
465493
"""
466494
return any(c not in MXID_LOCALPART_ALLOWED_CHARACTERS for c in localpart)
467495

468496

497+
def is_compliant_user_id_localpart(localpart: str) -> bool:
498+
"""
499+
Validates that the given user ID localpart is within the "compliant" range,
500+
i.e. not empty and all characters are between U+0021 and U+007E inclusive.
501+
See https://spec.matrix.org/v1.17/appendices/#historical-user-ids
502+
503+
To check if a localpart is non-historical, use contains_invalid_mxid_characters instead.
504+
505+
This should be used with care: there are existing non-compliant user IDs in the
506+
wild with empty or non-ASCII localparts, which will be rejected by this method.
507+
508+
Args:
509+
localpart: the localpart to be checked
510+
511+
Returns:
512+
True if the localpart is compliant, False otherwise
513+
"""
514+
if not localpart:
515+
return False
516+
return all(0x21 <= ord(c) <= 0x7E for c in localpart)
517+
518+
469519
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
470520

471521
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid

tests/rest/client/test_mutual_rooms.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
4343
mutual_rooms.register_servlets,
4444
]
4545

46-
def default_config(self) -> dict:
47-
config = super().default_config()
48-
experimental = config.setdefault("experimental_features", {})
49-
experimental.setdefault("msc2666_enabled", True)
50-
return config
51-
5246
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
5347
config = self.default_config()
5448
return self.setup_test_homeserver(config=config)
@@ -62,27 +56,12 @@ def _get_mutual_rooms(
6256
) -> FakeChannel:
6357
return self.make_request(
6458
"GET",
65-
"/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms"
59+
"/_matrix/client/v1/mutual_rooms"
6660
f"?user_id={quote(other_user)}"
6761
+ (f"&from={quote(since_token)}" if since_token else ""),
6862
access_token=token,
6963
)
7064

71-
@unittest.override_config({"experimental_features": {"msc2666_enabled": False}})
72-
def test_mutual_rooms_no_experimental_flag(self) -> None:
73-
"""
74-
The endpoint should 404 if the experimental flag is not enabled.
75-
"""
76-
# Register a user.
77-
u1 = self.register_user("user1", "pass")
78-
u1_token = self.login(u1, "pass")
79-
80-
# Check that we're unable to query the endpoint due to the endpoint
81-
# being unrecognised.
82-
channel = self._get_mutual_rooms(u1_token, "@not-used:test")
83-
self.assertEqual(404, channel.code, channel.result)
84-
self.assertEqual("M_UNRECOGNIZED", channel.json_body["errcode"], channel.result)
85-
8665
def test_shared_room_list_public(self) -> None:
8766
"""
8867
A room should show up in the shared list of rooms between two users
@@ -129,6 +108,7 @@ def _check_mutual_rooms_with(
129108
channel = self._get_mutual_rooms(u1_token, u2)
130109
self.assertEqual(200, channel.code, channel.result)
131110
self.assertEqual(len(channel.json_body["joined"]), 1)
111+
self.assertEqual(channel.json_body["count"], 1)
132112
self.assertEqual(channel.json_body["joined"][0], room_id_one)
133113

134114
# Create another room and invite user2 to it
@@ -142,6 +122,7 @@ def _check_mutual_rooms_with(
142122
channel = self._get_mutual_rooms(u1_token, u2)
143123
self.assertEqual(200, channel.code, channel.result)
144124
self.assertEqual(len(channel.json_body["joined"]), 2)
125+
self.assertEqual(channel.json_body["count"], 2)
145126
for room_id_id in channel.json_body["joined"]:
146127
self.assertIn(room_id_id, [room_id_one, room_id_two])
147128

@@ -167,11 +148,13 @@ def test_shared_room_list_pagination_two_pages(self) -> None:
167148
channel = self._get_mutual_rooms(u1_token, u2)
168149
self.assertEqual(200, channel.code, channel.result)
169150
self.assertEqual(channel.json_body["joined"], room_ids[0:10])
151+
self.assertEqual(channel.json_body["count"], 15)
170152
self.assertIn("next_batch", channel.json_body)
171153

172154
channel = self._get_mutual_rooms(u1_token, u2, channel.json_body["next_batch"])
173155
self.assertEqual(200, channel.code, channel.result)
174156
self.assertEqual(channel.json_body["joined"], room_ids[10:20])
157+
self.assertEqual(channel.json_body["count"], 15)
175158
self.assertNotIn("next_batch", channel.json_body)
176159

177160
def test_shared_room_list_pagination_one_page(self) -> None:
@@ -180,6 +163,7 @@ def test_shared_room_list_pagination_one_page(self) -> None:
180163
channel = self._get_mutual_rooms(u1_token, u2)
181164
self.assertEqual(200, channel.code, channel.result)
182165
self.assertEqual(channel.json_body["joined"], room_ids)
166+
self.assertEqual(channel.json_body["count"], 10)
183167
self.assertNotIn("next_batch", channel.json_body)
184168

185169
def test_shared_room_list_pagination_invalid_token(self) -> None:
@@ -209,6 +193,7 @@ def test_shared_room_list_after_leave(self) -> None:
209193
channel = self._get_mutual_rooms(u1_token, u2)
210194
self.assertEqual(200, channel.code, channel.result)
211195
self.assertEqual(len(channel.json_body["joined"]), 1)
196+
self.assertEqual(channel.json_body["count"], 1)
212197
self.assertEqual(channel.json_body["joined"][0], room)
213198

214199
self.helper.leave(room, user=u1, tok=u1_token)
@@ -217,11 +202,13 @@ def test_shared_room_list_after_leave(self) -> None:
217202
channel = self._get_mutual_rooms(u1_token, u2)
218203
self.assertEqual(200, channel.code, channel.result)
219204
self.assertEqual(len(channel.json_body["joined"]), 0)
205+
self.assertEqual(channel.json_body["count"], 0)
220206

221207
# Check user2's view of shared rooms with user1
222208
channel = self._get_mutual_rooms(u2_token, u1)
223209
self.assertEqual(200, channel.code, channel.result)
224210
self.assertEqual(len(channel.json_body["joined"]), 0)
211+
self.assertEqual(channel.json_body["count"], 0)
225212

226213
def test_shared_room_list_nonexistent_user(self) -> None:
227214
u1 = self.register_user("user1", "pass")
@@ -232,4 +219,27 @@ def test_shared_room_list_nonexistent_user(self) -> None:
232219
channel = self._get_mutual_rooms(u1_token, "@meow:example.com")
233220
self.assertEqual(200, channel.code, channel.result)
234221
self.assertEqual(len(channel.json_body["joined"]), 0)
222+
self.assertEqual(channel.json_body["count"], 0)
235223
self.assertNotIn("next_batch", channel.json_body)
224+
225+
def test_shared_room_list_invalid_user(self) -> None:
226+
u1 = self.register_user("user1", "pass")
227+
u1_token = self.login(u1, "pass")
228+
229+
channel = self._get_mutual_rooms(u1_token, "@:example.com")
230+
self.assertEqual(400, channel.code, channel.result)
231+
self.assertEqual(
232+
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
233+
)
234+
235+
channel = self._get_mutual_rooms(u1_token, "@" + "a" * 255 + ":example.com")
236+
self.assertEqual(400, channel.code, channel.result)
237+
self.assertEqual(
238+
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
239+
)
240+
241+
channel = self._get_mutual_rooms(u1_token, "@🐈️:example.com")
242+
self.assertEqual(400, channel.code, channel.result)
243+
self.assertEqual(
244+
"M_INVALID_PARAM", channel.json_body["errcode"], channel.result
245+
)

0 commit comments

Comments
 (0)