From 01ca4330bae42706a521d20f3c82969a2783c4f7 Mon Sep 17 00:00:00 2001
From: Mark Kittisopikul
Date: Sat, 30 May 2026 17:27:55 -0400
Subject: [PATCH] feat: add per-key locking to MemoryStore and LocalStore set_range

MemoryStore uses per-key asyncio.Lock / threading.Lock to serialise
concurrent set_range / set_range_sync calls on the same key.

LocalStore takes the same per-key in-process lock and then a cross-process
lock file: the writer exclusively creates a key.__lock__ sibling, writes the
byte range into the target file in place, and deletes the lock file when
done. A lock file older than 60 s is treated as abandoned and is removed by
the next waiter (best effort). The target file is never moved, so the key
stays visible to readers while a write is in progress.

Concurrent tests are added for both stores verifying that parallel writes to
non-overlapping byte ranges produce the correct combined result, plus a
LocalStore test that a stale lock file does not block a writer.

Co-Authored-By: Claude Sonnet 4.6
---
 src/zarr/storage/_local.py      | 63 +++++++++++++++++++++++++++++++-
 src/zarr/storage/_memory.py     | 17 ++++++++-
 tests/test_store/test_local.py  | 67 ++++++++++++++++++++++++++++++++++
 tests/test_store/test_memory.py | 48 ++++++++++++++++++++++++
 4 files changed, 191 insertions(+), 4 deletions(-)

diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py
index f9849a343d..a677587b28 100644
--- a/src/zarr/storage/_local.py
+++ b/src/zarr/storage/_local.py
@@ -6,6 +6,8 @@ import os
 import shutil
 import sys
+import threading
+import time
 import uuid
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self
 
@@ -59,6 +61,37 @@ def _safe_move(src: Path, dst: Path) -> None:
         os.unlink(src)
 
 
+_LOCK_POLL_INTERVAL = 0.01  # seconds between lock-file existence checks
+_LOCK_STALE_TIMEOUT = 60.0  # seconds before an abandoned lock file is reclaimed
+
+
+def _is_stale_lock(lock_path: Path) -> bool:
+    """Return True if lock_path either doesn't exist or is older than _LOCK_STALE_TIMEOUT."""
+    try:
+        return time.time() - lock_path.stat().st_mtime > _LOCK_STALE_TIMEOUT
+    except FileNotFoundError:
+        return True
+
+
+def _try_create_lock_file(lock_path: Path) -> bool:
+    """
+    Try once to take the cross-process lock by exclusively creating lock_path.
+
+    Return True if this call created the file. Otherwise return False, after
+    deleting the existing lock file if it is older than _LOCK_STALE_TIMEOUT
+    so that a later attempt can succeed.
+    """
+    try:
+        with lock_path.open("xb"):
+            return True
+    except FileExistsError:
+        if _is_stale_lock(lock_path):
+            # Best effort: another waiter may already have removed it.
+            with contextlib.suppress(FileNotFoundError):
+                lock_path.unlink()
+        return False
+
+
 @contextlib.contextmanager
 def _atomic_write(
     path: Path,
@@ -118,6 +151,8 @@ class LocalStore(Store, SupportsSetRange):
     supports_listing: bool = True
 
     root: Path
+    _key_locks: dict[str, asyncio.Lock]
+    _key_locks_sync: dict[str, threading.Lock]
 
     def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
         super().__init__(read_only=read_only)
@@ -128,6 +163,8 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
                 f"'root' must be a string or Path instance. Got an instance of {type(root)} instead."
             )
         self.root = root
+        self._key_locks = {}
+        self._key_locks_sync = {}
 
     def with_read_only(self, read_only: bool = False) -> Self:
         # docstring inherited
@@ -306,13 +343,35 @@ async def set_range(self, key: str, value: Buffer, start: int) -> None:
             await self._open()
         self._check_writable()
         path = self.root / key
-        await asyncio.to_thread(_put_range, path, value, start)
+        lock_path = path.with_name(path.name + ".__lock__")
+        in_process_lock = self._key_locks.setdefault(key, asyncio.Lock())
+
+        # Serialise coroutines of this process first, then poll for the
+        # cross-process lock file. The data file stays at `path` throughout,
+        # so the key never disappears while a write is in progress.
+        async with in_process_lock:
+            while not await asyncio.to_thread(_try_create_lock_file, lock_path):
+                await asyncio.sleep(_LOCK_POLL_INTERVAL)
+            try:
+                await asyncio.to_thread(_put_range, path, value, start)
+            finally:
+                await asyncio.to_thread(lock_path.unlink, missing_ok=True)
 
     def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
         self._ensure_open_sync()
         self._check_writable()
         path = self.root / key
-        _put_range(path, value, start)
+        lock_path = path.with_name(path.name + ".__lock__")
+        in_process_lock = self._key_locks_sync.setdefault(key, threading.Lock())
+
+        # Same protocol as the async path: in-process lock, then lock file.
+        with in_process_lock:
+            while not _try_create_lock_file(lock_path):
+                time.sleep(_LOCK_POLL_INTERVAL)
+            try:
+                _put_range(path, value, start)
+            finally:
+                lock_path.unlink(missing_ok=True)
 
     async def delete(self, key: str) -> None:
         """
diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py
index 54cf300098..f6bca4afc9 100644
--- a/src/zarr/storage/_memory.py
+++ b/src/zarr/storage/_memory.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import os
 import threading
 import weakref
@@ -49,6 +50,8 @@ class MemoryStore(Store, SupportsSetRange):
     supports_listing: bool = True
 
     _store_dict: MutableMapping[str, Buffer]
+    _key_locks: dict[str, asyncio.Lock]
+    _key_locks_sync: dict[str, threading.Lock]
 
     def __init__(
         self,
@@ -60,6 +63,8 @@ def __init__(
         if store_dict is None:
             store_dict = {}
         self._store_dict = store_dict
+        self._key_locks = {}
+        self._key_locks_sync = {}
 
     def with_read_only(self, read_only: bool = False) -> MemoryStore:
         # docstring inherited
@@ -206,13 +211,17 @@ def _set_range_impl(self, key: str, value: Buffer, start: int) -> None:
     async def set_range(self, key: str, value: Buffer, start: int) -> None:
         self._check_writable()
         await self._ensure_open()
-        self._set_range_impl(key, value, start)
+        lock = self._key_locks.setdefault(key, asyncio.Lock())
+        async with lock:
+            self._set_range_impl(key, value, start)
 
     def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
         self._check_writable()
         if not self._is_open:
             self._is_open = True
-        self._set_range_impl(key, value, start)
+        lock = self._key_locks_sync.setdefault(key, threading.Lock())
+        with lock:
+            self._set_range_impl(key, value, start)
 
     async def list(self) -> AsyncIterator[str]:
         # docstring inherited
@@ -729,6 +738,8 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool =
         # Get or create a managed dict from the registry
         self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name)
         self.path = normalize_path(path)
+        self._key_locks = {}
+        self._key_locks_sync = {}
 
     def __str__(self) -> str:
         return _join_paths([f"memory://{self._name}", self.path])
@@ -764,6 +775,8 @@ def _from_managed_dict(
         store._store_dict = managed_dict
         store._name = name
         store.path = normalize_path(path)
+        store._key_locks = {}
+        store._key_locks_sync = {}
         return store
 
     def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore:
diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py
index 22f17ef87e..8c6e37b8f4 100644
--- a/tests/test_store/test_local.py
+++ b/tests/test_store/test_local.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import asyncio
 import json
+import os
 import pathlib
 import re
+import threading
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -229,6 +232,70 @@ def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None:
         observed = sync(self.get(store_not_open, "test/key"))
         assert observed.to_bytes() == b"XXAAAAAAAA"
 
+    async def test_set_range_concurrent(self, store: LocalStore) -> None:
+        """Concurrent set_range calls to non-overlapping ranges should not corrupt data."""
+        n_writers = 10
+        chunk_size = 10
+        total = n_writers * chunk_size
+        await store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))
+
+        async def write_chunk(i: int) -> None:
+            data = bytes([i] * chunk_size)
+            await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
+
+        await asyncio.gather(*[write_chunk(i) for i in range(n_writers)])
+
+        result = await store.get("test/key", prototype=cpu.buffer_prototype)
+        assert result is not None
+        expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
+        assert result.to_bytes() == expected
+
+    def test_set_range_sync_concurrent(self, store: LocalStore) -> None:
+        """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data."""
+        n_writers = 10
+        chunk_size = 10
+        total = n_writers * chunk_size
+        sync(store.set("test/key", cpu.Buffer.from_bytes(bytes(total))))
+
+        errors: list[Exception] = []
+
+        def write_chunk(i: int) -> None:
+            try:
+                data = bytes([i] * chunk_size)
+                store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+        result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
+        assert result is not None
+        expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
+        assert result.to_bytes() == expected
+
+    def test_lock_file_cleaned_up(self, store: LocalStore) -> None:
+        """No lock file should remain after set_range_sync completes."""
+        sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
+        store.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0)
+        lock_path = store.root / "test" / "key.__lock__"
+        assert not lock_path.exists()
+
+    def test_set_range_sync_reclaims_stale_lock(self, store: LocalStore) -> None:
+        """A lock file left behind by a dead writer should not block set_range_sync."""
+        sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
+        lock_path = store.root / "test" / "key.__lock__"
+        lock_path.touch()
+        os.utime(lock_path, (0, 0))  # back-date the lock far beyond the stale timeout
+        store.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0)
+        assert not lock_path.exists()
+        observed = sync(self.get(store, "test/key"))
+        assert observed.to_bytes() == b"XXAAAAAAAA"
+
 
 @pytest.mark.parametrize("exclusive", [True, False])
 def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None:
diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py
index 5962dcb8f2..0d0a007b4e 100644
--- a/tests/test_store/test_memory.py
+++ b/tests/test_store/test_memory.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import re
+import threading
 from typing import TYPE_CHECKING, Any
 
 import numpy as np
@@ -193,6 +195,52 @@ def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None:
         assert getattr(store_not_open, "_is_open")  # noqa: B009
         assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA"
 
+    async def test_set_range_concurrent(self, store: MemoryStore) -> None:
+        """Concurrent set_range calls to non-overlapping ranges should not corrupt data."""
+        n_writers = 10
+        chunk_size = 10
+        total = n_writers * chunk_size
+        await store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))
+
+        async def write_chunk(i: int) -> None:
+            data = bytes([i] * chunk_size)
+            await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
+
+        await asyncio.gather(*[write_chunk(i) for i in range(n_writers)])
+
+        result = await store.get("test/key", prototype=cpu.buffer_prototype)
+        assert result is not None
+        expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
+        assert result.to_bytes() == expected
+
+    def test_set_range_sync_concurrent(self, store: MemoryStore) -> None:
+        """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data."""
+        n_writers = 10
+        chunk_size = 10
+        total = n_writers * chunk_size
+        store._store_dict["test/key"] = cpu.Buffer.from_bytes(bytes(total))
+
+        errors: list[Exception] = []
+
+        def write_chunk(i: int) -> None:
+            try:
+                data = bytes([i] * chunk_size)
+                store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+        result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
+        assert result is not None
+        expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
+        assert result.to_bytes() == expected
+
     # TODO: fix this warning
     @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")