Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit bb9081e

Browse files
committed
Use a flat list of algorithms instead of a map.
1 parent 2d6b903 commit bb9081e

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

synapse/appservice/api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,10 @@ async def claim_client_keys(
471471

472472
# Create the expected payload shape.
473473
body: Dict[str, Dict[str, List[str]]] = {}
474-
for user_id, device, algorithm, _count in query:
475-
# Note that only a single OTK can be claimed this way.
476-
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
474+
for user_id, device, algorithm, count in query:
475+
body.setdefault(user_id, {}).setdefault(device, []).extend(
476+
[algorithm] * count
477+
)
477478

478479
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
479480
try:

synapse/federation/federation_client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ async def query_user_devices(
237237
async def claim_client_keys(
238238
self,
239239
destination: str,
240-
content: Dict[str, Dict[str, Dict[str, int]]],
240+
query: Dict[str, Dict[str, Dict[str, int]]],
241241
timeout: Optional[int],
242242
) -> JsonDict:
243243
"""Claims one-time keys for a device hosted on a remote server.
@@ -251,24 +251,33 @@ async def claim_client_keys(
251251
"""
252252
sent_queries_counter.labels("client_one_time_keys").inc()
253253

254-
# Convert the query with counts into a legacy query and check if attempting
255-
# to claim more than 1 OTK.
256-
legacy_content: Dict[str, Dict[str, str]] = {}
254+
# Convert the query with counts into a stable and unstable query and check
255+
# if attempting to claim more than 1 OTK.
256+
content: Dict[str, Dict[str, str]] = {}
257+
unstable_content: Dict[str, Dict[str, List[str]]] = {}
257258
use_unstable = False
258-
for user_id, one_time_keys in content.items():
259+
for user_id, one_time_keys in query.items():
259260
for device_id, algorithms in one_time_keys.items():
260261
if any(count > 1 for count in algorithms.values()):
261262
use_unstable = True
262263
if algorithms:
263-
# Choose the first algorithm only.
264-
legacy_content.setdefault(user_id, {})[device_id] = next(
265-
iter(algorithms)
264+
# Choose the first algorithm only for the stable query.
265+
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
266+
# Flatten the map of algorithm -> count to a list repeating
267+
# each algorithm count times for the unstable query.
268+
unstable_content.setdefault(user_id, {})[device_id] = list(
269+
itertools.chain(
270+
*(
271+
itertools.repeat(algorithm, count)
272+
for algorithm, count in algorithms.items()
273+
)
274+
)
266275
)
267276

268277
if use_unstable:
269278
try:
270279
return await self.transport_layer.claim_client_keys_unstable(
271-
destination, content, timeout
280+
destination, unstable_content, timeout
272281
)
273282
except HttpResponseException as e:
274283
# If an error is received that is due to an unrecognised endpoint,
@@ -284,7 +293,7 @@ async def claim_client_keys(
284293
logger.debug("Skipping unstable claim client keys API")
285294

286295
return await self.transport_layer.claim_client_keys(
287-
destination, legacy_content, timeout
296+
destination, content, timeout
288297
)
289298

290299
@trace

synapse/federation/transport/server/federation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from collections import Counter
1516
from typing import (
1617
TYPE_CHECKING,
1718
Dict,
@@ -577,7 +578,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
577578
async def on_POST(
578579
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
579580
) -> Tuple[int, JsonDict]:
580-
# Flatten the request query.
581+
# Generate a count for each algorithm, which is hard-coded to 1.
581582
key_query: List[Tuple[str, str, str, int]] = []
582583
for user_id, device_keys in content.get("one_time_keys", {}).items():
583584
for device_id, algorithm in device_keys.items():
@@ -603,11 +604,12 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
603604
async def on_POST(
604605
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
605606
) -> Tuple[int, JsonDict]:
606-
# Flatten the request query.
607+
# Generate a count for each algorithm.
607608
key_query: List[Tuple[str, str, str, int]] = []
608609
for user_id, device_keys in content.get("one_time_keys", {}).items():
609610
for device_id, algorithms in device_keys.items():
610-
for algorithm, count in algorithms.items():
611+
counts = Counter(algorithms)
612+
for algorithm, count in counts.items():
611613
key_query.append((user_id, device_id, algorithm, count))
612614

613615
response = await self.handler.on_claim_client_keys(

synapse/rest/client/keys.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import re
19+
from collections import Counter
1920
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
2021

2122
from synapse.api.errors import InvalidAPICallError, SynapseError
@@ -290,7 +291,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
290291
timeout = parse_integer(request, "timeout", 10 * 1000)
291292
body = parse_json_object_from_request(request)
292293

293-
# Map the legacy request to the new request format.
294+
# Generate a count for each algorithm, which is hard-coded to 1.
294295
query: Dict[str, Dict[str, Dict[str, int]]] = {}
295296
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
296297
for device_id, algorithm in one_time_keys.items():
@@ -312,9 +313,8 @@ class UnstableOneTimeKeyServlet(RestServlet):
312313
{
313314
"one_time_keys": {
314315
"<user_id>": {
315-
"<device_id>": {
316-
"<algorithm>": <count>
317-
} } } }
316+
"<device_id>": ["<algorithm>", ...]
317+
} } }
318318
319319
HTTP/1.1 200 OK
320320
{
@@ -338,7 +338,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
338338
await self.auth.get_user_by_req(request, allow_guest=True)
339339
timeout = parse_integer(request, "timeout", 10 * 1000)
340340
body = parse_json_object_from_request(request)
341-
query = body.get("one_time_keys", {})
341+
342+
# Generate a count for each algorithm.
343+
query: Dict[str, Dict[str, Dict[str, int]]] = {}
344+
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
345+
for device_id, algorithms in one_time_keys.items():
346+
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
347+
342348
result = await self.e2e_keys_handler.claim_one_time_keys(
343349
query, timeout, always_include_fallback_keys=True
344350
)

0 commit comments

Comments
 (0)