-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Modify StoreKeyFetcher to read from server_keys_json. #15417
Changes from all commits
1270a5d
9d2c890
400ba67
de96650
6679a5e
bbef2a6
7f676a4
7ab918c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,10 +14,12 @@ | |
| # limitations under the License. | ||
|
|
||
| import itertools | ||
| import json | ||
| import logging | ||
| from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple | ||
|
|
||
| from signedjson.key import decode_verify_key_bytes | ||
| from unpaddedbase64 import decode_base64 | ||
|
|
||
| from synapse.storage._base import SQLBaseStore | ||
| from synapse.storage.database import LoggingTransaction | ||
|
|
@@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore): | |
| """Persistence for signature verification keys""" | ||
|
|
||
| @cached() | ||
| def _get_server_verify_key( | ||
| def _get_server_signature_key( | ||
| self, server_name_and_key_id: Tuple[str, str] | ||
| ) -> FetchKeyResult: | ||
| raise NotImplementedError() | ||
|
|
||
| @cachedList( | ||
| cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" | ||
| cached_method_name="_get_server_signature_key", | ||
| list_name="server_name_and_key_ids", | ||
| ) | ||
| async def get_server_verify_keys( | ||
| async def get_server_signature_keys( | ||
| self, server_name_and_key_ids: Iterable[Tuple[str, str]] | ||
| ) -> Dict[Tuple[str, str], FetchKeyResult]: | ||
| """ | ||
|
|
@@ -62,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: | |
| """Processes a batch of keys to fetch, and adds the result to `keys`.""" | ||
|
|
||
| # batch_iter always returns tuples so it's safe to do len(batch) | ||
| sql = ( | ||
| "SELECT server_name, key_id, verify_key, ts_valid_until_ms " | ||
| "FROM server_signature_keys WHERE 1=0" | ||
| ) + " OR (server_name=? AND key_id=?)" * len(batch) | ||
| sql = """ | ||
| SELECT server_name, key_id, verify_key, ts_valid_until_ms | ||
| FROM server_signature_keys WHERE 1=0 | ||
| """ + " OR (server_name=? AND key_id=?)" * len( | ||
| batch | ||
| ) | ||
|
|
||
| txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) | ||
|
|
||
|
|
@@ -89,9 +94,9 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: | |
| _get_keys(txn, batch) | ||
| return keys | ||
|
|
||
| return await self.db_pool.runInteraction("get_server_verify_keys", _txn) | ||
| return await self.db_pool.runInteraction("get_server_signature_keys", _txn) | ||
|
|
||
| async def store_server_verify_keys( | ||
| async def store_server_signature_keys( | ||
| self, | ||
| from_server: str, | ||
| ts_added_ms: int, | ||
|
|
@@ -119,7 +124,7 @@ async def store_server_verify_keys( | |
| ) | ||
| ) | ||
| # invalidate takes a tuple corresponding to the params of | ||
| # _get_server_verify_key. _get_server_verify_key only takes one | ||
| # _get_server_signature_key. _get_server_signature_key only takes one | ||
| # param, which is itself the 2-tuple (server_name, key_id). | ||
| invalidations.append((server_name, key_id)) | ||
|
|
||
|
|
@@ -134,10 +139,10 @@ async def store_server_verify_keys( | |
| "verify_key", | ||
| ), | ||
| value_values=value_values, | ||
| desc="store_server_verify_keys", | ||
| desc="store_server_signature_keys", | ||
| ) | ||
|
|
||
| invalidate = self._get_server_verify_key.invalidate | ||
| invalidate = self._get_server_signature_key.invalidate | ||
| for i in invalidations: | ||
| invalidate((i,)) | ||
|
|
||
|
|
@@ -180,16 +185,86 @@ async def store_server_keys_json( | |
| desc="store_server_keys_json", | ||
| ) | ||
|
|
||
| # invalidate takes a tuple corresponding to the params of | ||
| # _get_server_keys_json. _get_server_keys_json only takes one | ||
| # param, which is itself the 2-tuple (server_name, key_id). | ||
| self._get_server_keys_json.invalidate((((server_name, key_id),))) | ||
|
|
||
| @cached() | ||
| def _get_server_keys_json( | ||
| self, server_name_and_key_id: Tuple[str, str] | ||
| ) -> FetchKeyResult: | ||
| raise NotImplementedError() | ||
|
|
||
| @cachedList( | ||
| cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids" | ||
| ) | ||
| async def get_server_keys_json( | ||
| self, server_name_and_key_ids: Iterable[Tuple[str, str]] | ||
| ) -> Dict[Tuple[str, str], FetchKeyResult]: | ||
| """ | ||
| Args: | ||
| server_name_and_key_ids: | ||
| iterable of (server_name, key-id) tuples to fetch keys for | ||
|
Contributor
Should this be a `Collection` rather than an `Iterable`? I thought we try to avoid passing iterables to DB queries, because they might already be exhausted by the time we come to retry the transaction. (Or is this an `Iterable`-versus-`Iterator` thing?)
Member
Author
Maybe? Don't we pass iterables in like everywhere?!
Contributor
#11569 is what I had in mind. I'm happy for this to land as-is, since it's no worse than before and should stop trusted key servers from spamming hosts. That said, I would like to better understand whether Iterables are still a problem that we should worry about.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| Returns: | ||
| A map from (server_name, key_id) -> FetchKeyResult, or None if the | ||
| key is unknown | ||
| """ | ||
| keys = {} | ||
|
|
||
| def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: | ||
| """Processes a batch of keys to fetch, and adds the result to `keys`.""" | ||
|
|
||
| # batch_iter always returns tuples so it's safe to do len(batch) | ||
| sql = """ | ||
| SELECT server_name, key_id, key_json, ts_valid_until_ms | ||
| FROM server_keys_json WHERE 1=0 | ||
| """ + " OR (server_name=? AND key_id=?)" * len( | ||
| batch | ||
| ) | ||
|
|
||
| txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) | ||
|
|
||
| for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn: | ||
| if ts_valid_until_ms is None: | ||
| # Old keys may be stored with a ts_valid_until_ms of null, | ||
| # in which case we treat this as if it was set to `0`, i.e. | ||
| # it won't match key requests that define a minimum | ||
| # `ts_valid_until_ms`. | ||
| ts_valid_until_ms = 0 | ||
|
|
||
| # The entire signed JSON response is stored in server_keys_json, | ||
| # fetch out the bits needed. | ||
| key_json = json.loads(bytes(key_json_bytes)) | ||
DMRobertson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| key_base64 = key_json["verify_keys"][key_id]["key"] | ||
|
|
||
| keys[(server_name, key_id)] = FetchKeyResult( | ||
| verify_key=decode_verify_key_bytes( | ||
| key_id, decode_base64(key_base64) | ||
| ), | ||
| valid_until_ts=ts_valid_until_ms, | ||
| ) | ||
|
|
||
| def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: | ||
| for batch in batch_iter(server_name_and_key_ids, 50): | ||
| _get_keys(txn, batch) | ||
| return keys | ||
|
|
||
| return await self.db_pool.runInteraction("get_server_keys_json", _txn) | ||
|
|
||
| async def get_server_keys_json_for_remote( | ||
| self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] | ||
| ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: | ||
| """Retrieve the key json for a list of server_keys and key ids. | ||
| If no keys are found for a given server, key_id and source then | ||
| that server, key_id, and source triplet entry will be an empty list. | ||
| The JSON is returned as a byte array so that it can be efficiently | ||
clokep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| used in an HTTP response. | ||
|
|
||
| Args: | ||
| server_keys: List of (server_name, key_id, source) triplets. | ||
|
|
||
| Returns: | ||
| A mapping from (server_name, key_id, source) triplets to a list of dicts | ||
| """ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.