Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions src/zarr/storage/_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import shutil
import sys
import threading
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self
Expand Down Expand Up @@ -59,6 +61,18 @@ def _safe_move(src: Path, dst: Path) -> None:
os.unlink(src)


_LOCK_POLL_INTERVAL = 0.01 # seconds between lock-file existence checks
_LOCK_STALE_TIMEOUT = 60.0 # seconds before an abandoned lock file is reclaimed


def _is_stale_lock(lock_path: Path) -> bool:
"""Return True if lock_path either doesn't exist or is older than _LOCK_STALE_TIMEOUT."""
try:
return time.time() - lock_path.stat().st_mtime > _LOCK_STALE_TIMEOUT
except FileNotFoundError:
return True


@contextlib.contextmanager
def _atomic_write(
path: Path,
Expand Down Expand Up @@ -118,6 +132,8 @@ class LocalStore(Store, SupportsSetRange):
supports_listing: bool = True

root: Path
_key_locks: dict[str, asyncio.Lock]
_key_locks_sync: dict[str, threading.Lock]

def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
super().__init__(read_only=read_only)
Expand All @@ -128,6 +144,8 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
f"'root' must be a string or Path instance. Got an instance of {type(root)} instead."
)
self.root = root
self._key_locks = {}
self._key_locks_sync = {}

def with_read_only(self, read_only: bool = False) -> Self:
# docstring inherited
Expand Down Expand Up @@ -306,13 +324,76 @@ async def set_range(self, key: str, value: Buffer, start: int) -> None:
await self._open()
self._check_writable()
path = self.root / key
await asyncio.to_thread(_put_range, path, value, start)
lock_path = path.with_name(path.name + ".__lock__")
in_process_lock = self._key_locks.setdefault(key, asyncio.Lock())

# Acquire the file lock (steps 1-5 from the concurrency plan).
while True:
# Step 1: spin-wait until no lock file is present (or it is stale).
while await asyncio.to_thread(lock_path.exists):
if await asyncio.to_thread(_is_stale_lock, lock_path):
break
await asyncio.sleep(_LOCK_POLL_INTERVAL)

# Steps 2-5: serialise the rename under an in-process lock so that
# only one coroutine per process attempts the atomic file move at a time.
acquired = False
async with in_process_lock:
# Step 3: re-check after acquiring the in-process lock.
if not await asyncio.to_thread(lock_path.exists):
try:
# Step 4: atomic rename — raises FileExistsError if another
# process grabbed the lock between steps 3 and 4.
await asyncio.to_thread(_safe_move, path, lock_path)
acquired = True
except FileExistsError:
pass
# Step 5: in-process lock released on context exit.

if acquired:
break

# Step 6: perform the partial write on the lock file.
try:
await asyncio.to_thread(_put_range, lock_path, value, start)
finally:
# Steps 7-9: re-acquire in-process lock, rename lock file back, release.
async with in_process_lock:
await asyncio.to_thread(lock_path.replace, path)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
self._ensure_open_sync()
self._check_writable()
path = self.root / key
_put_range(path, value, start)
lock_path = path.with_name(path.name + ".__lock__")
in_process_lock = self._key_locks_sync.setdefault(key, threading.Lock())

# Acquire the file lock (same double-checked pattern as the async path).
while True:
# Step 1: spin-wait.
while lock_path.exists():
if _is_stale_lock(lock_path):
break
time.sleep(_LOCK_POLL_INTERVAL)

acquired = False
with in_process_lock:
if not lock_path.exists():
try:
_safe_move(path, lock_path)
acquired = True
except FileExistsError:
pass

if acquired:
break

# Partial write, then release.
try:
_put_range(lock_path, value, start)
finally:
with in_process_lock:
lock_path.replace(path)

async def delete(self, key: str) -> None:
"""
Expand Down
17 changes: 15 additions & 2 deletions src/zarr/storage/_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import os
import threading
import weakref
Expand Down Expand Up @@ -49,6 +50,8 @@ class MemoryStore(Store, SupportsSetRange):
supports_listing: bool = True

_store_dict: MutableMapping[str, Buffer]
_key_locks: dict[str, asyncio.Lock]
_key_locks_sync: dict[str, threading.Lock]

def __init__(
self,
Expand All @@ -60,6 +63,8 @@ def __init__(
if store_dict is None:
store_dict = {}
self._store_dict = store_dict
self._key_locks = {}
self._key_locks_sync = {}

def with_read_only(self, read_only: bool = False) -> MemoryStore:
# docstring inherited
Expand Down Expand Up @@ -206,13 +211,17 @@ def _set_range_impl(self, key: str, value: Buffer, start: int) -> None:
async def set_range(self, key: str, value: Buffer, start: int) -> None:
self._check_writable()
await self._ensure_open()
self._set_range_impl(key, value, start)
lock = self._key_locks.setdefault(key, asyncio.Lock())
async with lock:
self._set_range_impl(key, value, start)

def set_range_sync(self, key: str, value: Buffer, start: int) -> None:
self._check_writable()
if not self._is_open:
self._is_open = True
self._set_range_impl(key, value, start)
lock = self._key_locks_sync.setdefault(key, threading.Lock())
with lock:
self._set_range_impl(key, value, start)

async def list(self) -> AsyncIterator[str]:
# docstring inherited
Expand Down Expand Up @@ -729,6 +738,8 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool =
# Get or create a managed dict from the registry
self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name)
self.path = normalize_path(path)
self._key_locks = {}
self._key_locks_sync = {}

def __str__(self) -> str:
return _join_paths([f"memory://{self._name}", self.path])
Expand Down Expand Up @@ -764,6 +775,8 @@ def _from_managed_dict(
store._store_dict = managed_dict
store._name = name
store.path = normalize_path(path)
store._key_locks = {}
store._key_locks_sync = {}
return store

def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_store/test_local.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import json
import pathlib
import re
import threading
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -229,6 +231,59 @@ def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None:
observed = sync(self.get(store_not_open, "test/key"))
assert observed.to_bytes() == b"XXAAAAAAAA"

async def test_set_range_concurrent(self, store: LocalStore) -> None:
"""Concurrent set_range calls to non-overlapping ranges should not corrupt data."""
n_writers = 10
chunk_size = 10
total = n_writers * chunk_size
await store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))

async def write_chunk(i: int) -> None:
data = bytes([i] * chunk_size)
await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)

await asyncio.gather(*[write_chunk(i) for i in range(n_writers)])

result = await store.get("test/key", prototype=cpu.buffer_prototype)
assert result is not None
expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
assert result.to_bytes() == expected

def test_set_range_sync_concurrent(self, store: LocalStore) -> None:
"""Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data."""
n_writers = 10
chunk_size = 10
total = n_writers * chunk_size
sync(store.set("test/key", cpu.Buffer.from_bytes(bytes(total))))

errors: list[Exception] = []

def write_chunk(i: int) -> None:
try:
data = bytes([i] * chunk_size)
store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
except Exception as e:
errors.append(e)

threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)]
for t in threads:
t.start()
for t in threads:
t.join()

assert not errors
result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
assert result is not None
expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
assert result.to_bytes() == expected

def test_lock_file_cleaned_up(self, store: LocalStore) -> None:
"""No lock file should remain after set_range_sync completes."""
sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")))
store.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0)
lock_path = store.root / "test" / "key.__lock__"
assert not lock_path.exists()


@pytest.mark.parametrize("exclusive", [True, False])
def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None:
Expand Down
48 changes: 48 additions & 0 deletions tests/test_store/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import json
import re
import threading
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -193,6 +195,52 @@ def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None:
assert getattr(store_not_open, "_is_open") # noqa: B009
assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA"

async def test_set_range_concurrent(self, store: MemoryStore) -> None:
"""Concurrent set_range calls to non-overlapping ranges should not corrupt data."""
n_writers = 10
chunk_size = 10
total = n_writers * chunk_size
await store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))

async def write_chunk(i: int) -> None:
data = bytes([i] * chunk_size)
await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)

await asyncio.gather(*[write_chunk(i) for i in range(n_writers)])

result = await store.get("test/key", prototype=cpu.buffer_prototype)
assert result is not None
expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
assert result.to_bytes() == expected

def test_set_range_sync_concurrent(self, store: MemoryStore) -> None:
"""Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data."""
n_writers = 10
chunk_size = 10
total = n_writers * chunk_size
store._store_dict["test/key"] = cpu.Buffer.from_bytes(bytes(total))

errors: list[Exception] = []

def write_chunk(i: int) -> None:
try:
data = bytes([i] * chunk_size)
store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size)
except Exception as e:
errors.append(e)

threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)]
for t in threads:
t.start()
for t in threads:
t.join()

assert not errors
result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype)
assert result is not None
expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)])
assert result.to_bytes() == expected


# TODO: fix this warning
@pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")
Expand Down