diff --git a/src/infuse_iot/database.py b/src/infuse_iot/database.py index ad6466c..aec41e7 100644 --- a/src/infuse_iot/database.py +++ b/src/infuse_iot/database.py @@ -4,6 +4,8 @@ import binascii import pathlib import shelve +import threading +from contextlib import contextmanager from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 @@ -77,6 +79,7 @@ def __init__( self._local_root: x25519.X25519PrivateKey | None = None self._local_root_public: bytes | None = None self._cache_path = cache_path + self._cache_lock = threading.Lock() if local_root: with local_root.open() as f: private_key = serialization.load_pem_private_key(f.read().encode("utf-8"), password=None) @@ -130,18 +133,23 @@ def observe_secondary_remote_public_key(self, infuse_id: int, secondary_pub_key: dev.secondary_device_key_id = binascii.crc32(self._local_root_public + dev.device_public_key) & 0xFFFFFF dev.local_shared_key = self._local_root.exchange(device_public_key) + @contextmanager + def _with_cache(self): + with self._cache_lock as _lock, shelve.open(str(self._cache_path)) as cache: + yield cache + def _update_cache(self, infuse_id: int, device_pub_key: bytes, shared_key: bytes): if self._cache_path is None: return infuse_id_str = f"{infuse_id:016x}" - with shelve.open(str(self._cache_path)) as cache: + with self._with_cache() as cache: cache[infuse_id_str] = {"public_key": device_pub_key, "shared_key": shared_key} def _from_cache(self, infuse_id: int, device_pub_key: bytes) -> bytes | None: if self._cache_path is None: return None infuse_id_str = f"{infuse_id:016x}" - with shelve.open(str(self._cache_path)) as cache: + with self._with_cache() as cache: state = cache.get(infuse_id_str, None) if state is None or state["public_key"] != device_pub_key: return None