From da232aee33ee22f57059c13e3b5c7d1822655e53 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 15 Apr 2026 09:29:18 +0200 Subject: [PATCH 01/10] feat: add SupportsSetRange protocol and store implementations Add SupportsSetRange protocol for stores that support writing to a byte range within an existing value (set_range/set_range_sync). Implement in MemoryStore and LocalStore, both explicitly subclassing the protocol. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/zarr/abc/store.py | 10 ++++++++++ src/zarr/storage/_local.py | 23 ++++++++++++++++++++++- src/zarr/storage/_memory.py | 24 ++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 600df17ee5..9ec7c4cecc 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -22,6 +22,7 @@ "Store", "SupportsDeleteSync", "SupportsGetSync", + "SupportsSetRange", "SupportsSetSync", "SupportsSyncStore", "set_or_delete", @@ -709,6 +710,15 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +@runtime_checkable +class SupportsSetRange(Protocol): + """Protocol for stores that support writing to a byte range within an existing value.""" + + async def set_range(self, key: str, value: Buffer, start: int) -> None: ... + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: ... + + @runtime_checkable class SupportsGetSync(Protocol): def get_sync( diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 96f1e61746..a0eda303e1 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -16,6 +16,7 @@ RangeByteRequest, Store, SuffixByteRequest, + SupportsSetRange, ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype @@ -77,6 +78,13 @@ def _atomic_write( raise +def _put_range(path: Path, value: Buffer, start: int) -> None: + """Write bytes at a specific offset within an existing file.""" + with path.open("r+b") as f: + f.seek(start) + f.write(value.as_numpy_array().tobytes()) + + def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: path.parent.mkdir(parents=True, exist_ok=True) # write takes any object supporting the buffer protocol @@ -85,7 +93,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) -class LocalStore(Store): +class LocalStore(Store, SupportsSetRange): """ Store for the local file system. @@ -292,6 +300,19 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + if not self._is_open: + await self._open() + self._check_writable() + path = self.root / key + await asyncio.to_thread(_put_range, path, value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + _put_range(path, value, start) + async def delete(self, key: str) -> None: """ Remove a key from the store. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 1194894b9d..cb773ae30a 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -3,7 +3,7 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import ByteRequest, Store, SupportsSetRange from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map @@ -18,7 +18,7 @@ logger = getLogger(__name__) -class MemoryStore(Store): +class MemoryStore(Store, SupportsSetRange): """ Store for local memory. @@ -186,6 +186,26 @@ async def delete(self, key: str) -> None: except KeyError: logger.debug("Key %s does not exist.", key) + def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: + buf = self._store_dict[key] + target = buf.as_numpy_array() + if not target.flags.writeable: + target = target.copy() + self._store_dict[key] = buf.__class__(target) + source = value.as_numpy_array() + target[start : start + len(source)] = source + + async def set_range(self, key: str, value: Buffer, start: int) -> None: + self._check_writable() + await self._ensure_open() + self._set_range_impl(key, value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + self._set_range_impl(key, value, start) + async def list(self) -> AsyncIterator[str]: # docstring inherited for key in self._store_dict: From 579ff1642818158e08cad095823c1a5672521888 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 15 Apr 2026 10:48:55 +0200 Subject: [PATCH 02/10] test: add tests for SupportsSetRange on MemoryStore and LocalStore Tests cover isinstance check, async set_range, sync set_range_sync, and edge case (writing at end of value). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/zarr/abc/store.py | 10 ++++++- tests/test_store/test_local.py | 49 ++++++++++++++++++++++++++++++++ tests/test_store/test_memory.py | 50 +++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 9ec7c4cecc..c33651f016 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -712,7 +712,15 @@ async def set_if_not_exists(self, default: Buffer) -> None: ... @runtime_checkable class SupportsSetRange(Protocol): - """Protocol for stores that support writing to a byte range within an existing value.""" + """Protocol for stores that support writing to a byte range within an existing value. + + Overwrites ``len(value)`` bytes starting at byte offset ``start`` within the + existing stored value for ``key``. The key must already exist and the write + must fit within the existing value (i.e., ``start + len(value) <= len(existing)``). + + Behavior when the write extends past the end of the existing value is + implementation-specific and should not be relied upon. + """ async def set_range(self, key: str, value: Buffer, start: int) -> None: ... diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index bdc9b48121..0712cd1bca 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -10,6 +10,7 @@ import zarr from zarr import create_array +from zarr.abc.store import SupportsSetRange from zarr.core.buffer import Buffer, cpu from zarr.core.sync import sync from zarr.storage import LocalStore @@ -162,6 +163,54 @@ def test_get_json_sync_with_prototype_none( result = store._get_json_sync(key, prototype=buffer_cls) assert result == data + def test_supports_set_range(self, store: LocalStore) -> None: + """LocalStore should implement SupportsSetRange.""" + assert isinstance(store, SupportsSetRange) + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + async def test_set_range( + self, store: LocalStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range should overwrite bytes at the given offset.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + def test_set_range_sync( + self, store: LocalStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range_sync should overwrite bytes at the given offset.""" + sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 03c8b24271..d2554b411f 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -9,6 +9,7 @@ import pytest import zarr +from zarr.abc.store import SupportsSetRange from zarr.core.buffer import Buffer, cpu, gpu from zarr.core.sync import sync from zarr.errors import ZarrUserWarning @@ -127,6 +128,55 @@ def test_get_json_sync_with_prototype_none( result = store._get_json_sync(key, prototype=buffer_cls) assert result == data + def test_supports_set_range(self, store: MemoryStore) -> None: + """MemoryStore should implement SupportsSetRange.""" + assert isinstance(store, SupportsSetRange) + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + async def test_set_range( + self, store: MemoryStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range should overwrite bytes at the given offset.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + def test_set_range_sync( + self, store: MemoryStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range_sync should overwrite bytes at the given offset.""" + store._is_open = True + store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") From 2b9d80449d3a0c2954025ecfda1e5dc0800aea3b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 15 Apr 2026 11:06:22 +0200 Subject: [PATCH 03/10] docs: changelog --- changes/3907.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3907.feature.md diff --git a/changes/3907.feature.md b/changes/3907.feature.md new file mode 100644 index 0000000000..66b908d305 --- /dev/null +++ b/changes/3907.feature.md @@ -0,0 +1 @@ +Add protocols for stores that support byte-range-writes. This is necessary to support in-place writes of sharded arrays. \ No newline at end of file From 5c26a08e4c5d709a8effa88745b5faa2100cb32c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 1 May 2026 08:15:46 -0400 Subject: [PATCH 04/10] test: add tests for open / not open --- tests/test_store/test_local.py | 18 ++++++++++++++++++ tests/test_store/test_memory.py | 18 +++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 0712cd1bca..0200f99c75 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -211,6 +211,24 @@ def test_set_range_sync( assert result is not None assert result.to_bytes() == expected + async def test_set_range_not_open(self, store_not_open: LocalStore) -> None: + """set_range auto-opens a closed store.""" + assert not store_not_open._is_open + await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert store_not_open._is_open + observed = await self.get(store_not_open, "test/key") + assert observed.to_bytes() == b"XXAAAAAAAA" + + def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None: + """set_range_sync auto-opens a closed store.""" + assert not store_not_open._is_open + sync(self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert store_not_open._is_open + observed = sync(self.get(store_not_open, "test/key")) + assert observed.to_bytes() == b"XXAAAAAAAA" + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 8320e96b43..dc63644c2c 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -170,13 +170,29 @@ def test_set_range_sync( self, store: MemoryStore, start: int, patch: bytes, expected: bytes ) -> None: """set_range_sync should overwrite bytes at the given offset.""" - store._is_open = True store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start) result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) assert result is not None assert result.to_bytes() == expected + async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None: + """set_range auto-opens a closed store.""" + assert not store_not_open._is_open + await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert store_not_open._is_open + observed = await self.get(store_not_open, "test/key") + assert observed.to_bytes() == b"XXAAAAAAAA" + + def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None: + """set_range_sync auto-opens a closed store.""" + assert not store_not_open._is_open + store_not_open._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert store_not_open._is_open + assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA" + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") From 91590dd28f34e628bd32652a6d8e3e3334527ed3 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 1 May 2026 08:26:11 -0400 Subject: [PATCH 05/10] fixup --- tests/test_store/test_local.py | 4 ++-- tests/test_store/test_memory.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 0200f99c75..d28fd71211 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -216,7 +216,7 @@ async def test_set_range_not_open(self, store_not_open: LocalStore) -> None: assert not store_not_open._is_open await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open + assert store_not_open._is_open is True observed = await self.get(store_not_open, "test/key") assert observed.to_bytes() == b"XXAAAAAAAA" @@ -225,7 +225,7 @@ def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None: assert not store_not_open._is_open sync(self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open + assert store_not_open._is_open is True observed = sync(self.get(store_not_open, "test/key")) assert observed.to_bytes() == b"XXAAAAAAAA" diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index dc63644c2c..9d58d8a1c6 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -181,7 +181,7 @@ async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None: assert not store_not_open._is_open await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open + assert store_not_open._is_open is True observed = await self.get(store_not_open, "test/key") assert observed.to_bytes() == b"XXAAAAAAAA" @@ -190,7 +190,7 @@ def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None: assert not store_not_open._is_open store_not_open._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open + assert store_not_open._is_open is True assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA" From a9da33aab157b5637b87f1c9d34cd10e74fc8a5e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 1 May 2026 08:41:38 -0400 Subject: [PATCH 06/10] chore: mypy --- tests/test_store/test_local.py | 4 ++-- tests/test_store/test_memory.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index d28fd71211..22f17ef87e 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -216,7 +216,7 @@ async def test_set_range_not_open(self, store_not_open: LocalStore) -> None: assert not store_not_open._is_open await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open is True + assert getattr(store_not_open, "_is_open") # noqa: B009 observed = await self.get(store_not_open, "test/key") assert observed.to_bytes() == b"XXAAAAAAAA" @@ -225,7 +225,7 @@ def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None: assert not store_not_open._is_open sync(self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open is True + assert getattr(store_not_open, "_is_open") # noqa: B009 observed = sync(self.get(store_not_open, "test/key")) assert observed.to_bytes() == b"XXAAAAAAAA" diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 12968618e6..5962dcb8f2 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -181,7 +181,7 @@ async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None: assert not store_not_open._is_open await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open is True + assert getattr(store_not_open, "_is_open") # noqa: B009 observed = await self.get(store_not_open, "test/key") assert observed.to_bytes() == b"XXAAAAAAAA" @@ -190,7 +190,7 @@ def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None: assert not store_not_open._is_open store_not_open._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) - assert store_not_open._is_open is True + assert getattr(store_not_open, "_is_open") # noqa: B009 assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA" From 225cea0d7c21a3b2deb2e7b6f1db788c5c12b782 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Sat, 30 May 2026 21:39:09 +0200 Subject: [PATCH 07/10] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/zarr/storage/_local.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index fa1266286c..f9849a343d 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -80,9 +80,10 @@ def _atomic_write( def _put_range(path: Path, value: Buffer, start: int) -> None: """Write bytes at a specific offset within an existing file.""" + view = value.as_buffer_like() with path.open("r+b") as f: f.seek(start) - f.write(value.as_numpy_array().tobytes()) + f.write(view) def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: From 2db2c9a034fdb24a2b52ec4615bf546073ae3c97 Mon Sep 17 00:00:00 2001 From: Mark Kittisopikul Date: Sun, 31 May 2026 13:21:17 -0400 Subject: [PATCH 08/10] feat: add per-key locking to MemoryStore and LocalStore set_range (#178) --- src/zarr/storage/_local.py | 85 ++++++++++++++++++++++++++++++++- src/zarr/storage/_memory.py | 17 ++++++- tests/test_store/test_local.py | 55 +++++++++++++++++++++ tests/test_store/test_memory.py | 48 +++++++++++++++++++ 4 files changed, 201 insertions(+), 4 deletions(-) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f9849a343d..a677587b28 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -6,6 +6,8 @@ import os import shutil import sys +import threading +import time import uuid from pathlib import Path from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self @@ -59,6 +61,18 @@ def _safe_move(src: Path, dst: Path) -> None: os.unlink(src) +_LOCK_POLL_INTERVAL = 0.01 # seconds between lock-file existence checks +_LOCK_STALE_TIMEOUT = 60.0 # seconds before an abandoned lock file is reclaimed + + +def _is_stale_lock(lock_path: Path) -> bool: + """Return True if lock_path either doesn't exist or is older than _LOCK_STALE_TIMEOUT.""" + try: + return time.time() - lock_path.stat().st_mtime > _LOCK_STALE_TIMEOUT + except FileNotFoundError: + return True + + @contextlib.contextmanager def _atomic_write( path: Path, @@ -118,6 +132,8 @@ class LocalStore(Store, SupportsSetRange): supports_listing: bool = True root: Path + _key_locks: dict[str, asyncio.Lock] + _key_locks_sync: dict[str, threading.Lock] def __init__(self, root: Path | str, *, read_only: bool = False) -> None: super().__init__(read_only=read_only) @@ -128,6 +144,8 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None: f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." ) self.root = root + self._key_locks = {} + self._key_locks_sync = {} def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited @@ -306,13 +324,76 @@ async def set_range(self, key: str, value: Buffer, start: int) -> None: await self._open() self._check_writable() path = self.root / key - await asyncio.to_thread(_put_range, path, value, start) + lock_path = path.with_name(path.name + ".__lock__") + in_process_lock = self._key_locks.setdefault(key, asyncio.Lock()) + + # Acquire the file lock (steps 1-5 from the concurrency plan). + while True: + # Step 1: spin-wait until no lock file is present (or it is stale). + while await asyncio.to_thread(lock_path.exists): + if await asyncio.to_thread(_is_stale_lock, lock_path): + break + await asyncio.sleep(_LOCK_POLL_INTERVAL) + + # Steps 2-5: serialise the rename under an in-process lock so that + # only one coroutine per process attempts the atomic file move at a time. + acquired = False + async with in_process_lock: + # Step 3: re-check after acquiring the in-process lock. + if not await asyncio.to_thread(lock_path.exists): + try: + # Step 4: atomic rename — raises FileExistsError if another + # process grabbed the lock between steps 3 and 4. + await asyncio.to_thread(_safe_move, path, lock_path) + acquired = True + except FileExistsError: + pass + # Step 5: in-process lock released on context exit. + + if acquired: + break + + # Step 6: perform the partial write on the lock file. + try: + await asyncio.to_thread(_put_range, lock_path, value, start) + finally: + # Steps 7-9: re-acquire in-process lock, rename lock file back, release. + async with in_process_lock: + await asyncio.to_thread(lock_path.replace, path) def set_range_sync(self, key: str, value: Buffer, start: int) -> None: self._ensure_open_sync() self._check_writable() path = self.root / key - _put_range(path, value, start) + lock_path = path.with_name(path.name + ".__lock__") + in_process_lock = self._key_locks_sync.setdefault(key, threading.Lock()) + + # Acquire the file lock (same double-checked pattern as the async path). + while True: + # Step 1: spin-wait. + while lock_path.exists(): + if _is_stale_lock(lock_path): + break + time.sleep(_LOCK_POLL_INTERVAL) + + acquired = False + with in_process_lock: + if not lock_path.exists(): + try: + _safe_move(path, lock_path) + acquired = True + except FileExistsError: + pass + + if acquired: + break + + # Partial write, then release. + try: + _put_range(lock_path, value, start) + finally: + with in_process_lock: + lock_path.replace(path) async def delete(self, key: str) -> None: """ diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 54cf300098..f6bca4afc9 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import os import threading import weakref @@ -49,6 +50,8 @@ class MemoryStore(Store, SupportsSetRange): supports_listing: bool = True _store_dict: MutableMapping[str, Buffer] + _key_locks: dict[str, asyncio.Lock] + _key_locks_sync: dict[str, threading.Lock] def __init__( self, @@ -60,6 +63,8 @@ def __init__( if store_dict is None: store_dict = {} self._store_dict = store_dict + self._key_locks = {} + self._key_locks_sync = {} def with_read_only(self, read_only: bool = False) -> MemoryStore: # docstring inherited @@ -206,13 +211,17 @@ def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: async def set_range(self, key: str, value: Buffer, start: int) -> None: self._check_writable() await self._ensure_open() - self._set_range_impl(key, value, start) + lock = self._key_locks.setdefault(key, asyncio.Lock()) + async with lock: + self._set_range_impl(key, value, start) def set_range_sync(self, key: str, value: Buffer, start: int) -> None: self._check_writable() if not self._is_open: self._is_open = True - self._set_range_impl(key, value, start) + lock = self._key_locks_sync.setdefault(key, threading.Lock()) + with lock: + self._set_range_impl(key, value, start) async def list(self) -> AsyncIterator[str]: # docstring inherited @@ -729,6 +738,8 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool = # Get or create a managed dict from the registry self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name) self.path = normalize_path(path) + self._key_locks = {} + self._key_locks_sync = {} def __str__(self) -> str: return _join_paths([f"memory://{self._name}", self.path]) @@ -764,6 +775,8 @@ def _from_managed_dict( store._store_dict = managed_dict store._name = name store.path = normalize_path(path) + store._key_locks = {} + store._key_locks_sync = {} return store def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore: diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 22f17ef87e..8c6e37b8f4 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio import json import pathlib import re +import threading from typing import TYPE_CHECKING import numpy as np @@ -229,6 +231,59 @@ def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None: observed = sync(self.get(store_not_open, "test/key")) assert observed.to_bytes() == b"XXAAAAAAAA" + async def test_set_range_concurrent(self, store: LocalStore) -> None: + """Concurrent set_range calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + await store.set("test/key", cpu.Buffer.from_bytes(bytes(total))) + + async def write_chunk(i: int) -> None: + data = bytes([i] * chunk_size) + await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + + await asyncio.gather(*[write_chunk(i) for i in range(n_writers)]) + + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_set_range_sync_concurrent(self, store: LocalStore) -> None: + """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + sync(store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))) + + errors: list[Exception] = [] + + def write_chunk(i: int) -> None: + try: + data = bytes([i] * chunk_size) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_lock_file_cleaned_up(self, store: LocalStore) -> None: + """No lock file should remain after set_range_sync completes.""" + sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + lock_path = store.root / "test" / "key.__lock__" + assert not lock_path.exists() + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 5962dcb8f2..0d0a007b4e 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import json import re +import threading from typing import TYPE_CHECKING, Any import numpy as np @@ -193,6 +195,52 @@ def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None: assert getattr(store_not_open, "_is_open") # noqa: B009 assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA" + async def test_set_range_concurrent(self, store: MemoryStore) -> None: + """Concurrent set_range calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + await store.set("test/key", cpu.Buffer.from_bytes(bytes(total))) + + async def write_chunk(i: int) -> None: + data = bytes([i] * chunk_size) + await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + + await asyncio.gather(*[write_chunk(i) for i in range(n_writers)]) + + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_set_range_sync_concurrent(self, store: MemoryStore) -> None: + """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + store._store_dict["test/key"] = cpu.Buffer.from_bytes(bytes(total)) + + errors: list[Exception] = [] + + def write_chunk(i: int) -> None: + try: + data = bytes([i] * chunk_size) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") From e02437cd59f2a8959408f684b66bd35d0b976505 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 3 Jun 2026 20:49:01 +0200 Subject: [PATCH 09/10] fix: set_range capability detection, bounds check, and managed path prefix Layered on top of the per-key locking from #178. Addresses branch-review findings not covered by that change: - GpuMemoryStore must not satisfy SupportsSetRange: the inherited MemoryStore implementation mutates a host copy of the GPU buffer and would silently lose the write. MemoryStore/LocalStore now satisfy the protocol *structurally* (dropped the nominal `SupportsSetRange` base), so GpuMemoryStore disclaims set_range/set_range_sync by setting them to None and isinstance(gpu_store, SupportsSetRange) is now False. - Uniform out-of-bounds handling: add _check_set_range_bounds (storage/_utils), enforcing start + len(value) <= len(existing) in both MemoryStore and LocalStore so a too-large write raises a clear ValueError instead of LocalStore silently extending the file or MemoryStore raising an opaque numpy error. Matches zarrs, which validates ranges rather than extending. - ManagedMemoryStore.set_range/set_range_sync now apply the store's path prefix like its other key-bearing methods, so a path-prefixed store targets the correct key. - Expand the SupportsSetRange docstring with the caller-owns-consistency concurrency contract and record the opt-in-protocol-vs-universal-method design decision; update the changelog. Adds out-of-bounds tests (MemoryStore, LocalStore), a GpuMemoryStore capability test, and a path-prefixed ManagedMemoryStore set_range test. Verified on GPU hardware. Co-Authored-By: Claude Opus 4.8 (1M context) --- changes/3907.feature.md | 6 +++- src/zarr/abc/store.py | 51 +++++++++++++++++++++++++++++---- src/zarr/storage/_local.py | 10 +++++-- src/zarr/storage/_memory.py | 27 +++++++++++++++-- src/zarr/storage/_utils.py | 17 +++++++++++ tests/test_store/test_local.py | 17 +++++++++++ tests/test_store/test_memory.py | 46 +++++++++++++++++++++++++++++ 7 files changed, 163 insertions(+), 11 deletions(-) diff --git a/changes/3907.feature.md b/changes/3907.feature.md index 66b908d305..e9fa75d632 100644 --- a/changes/3907.feature.md +++ b/changes/3907.feature.md @@ -1 +1,5 @@ -Add protocols for stores that support byte-range-writes. This is necessary to support in-place writes of sharded arrays. \ No newline at end of file +Add the `SupportsSetRange` protocol for stores that support writing to a byte range within an existing value, implemented by `LocalStore` and `MemoryStore`. This is necessary to support in-place writes of sharded arrays (e.g. writing a single subchunk without rewriting the entire shard). + +Byte-range writes are exposed as an opt-in protocol rather than a method on the `Store` ABC. Only a few stores can perform them natively, and most cannot. A universal method with a read-modify-write fallback (as in the Rust `zarrs` crate) would let every store participate, but for the motivating use case that fallback would silently rewrite an entire shard, defeating the purpose. The opt-in protocol keeps the cost model honest and keeps `set_range` out of the signatures of stores that will never support it; any fallback strategy is left to the caller (the sharding codec). Stores satisfy the protocol structurally, so `GpuMemoryStore` (which has no use case for in-place GPU byte-range writes) disclaims it and is correctly reported as unsupported by `isinstance`. + +It is entirely the caller's responsibility to ensure consistency: concurrent writes to overlapping ranges are order-dependent, `set_range` racing against `set`/`delete` is a race, and writes are not guaranteed to be atomic with respect to a process crash. A write that does not fit within the existing value raises `ValueError` consistently across implementations. diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 60a0e938cd..0384f391e8 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -788,16 +788,55 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +# Design note: byte-range writes are exposed as an opt-in protocol rather than a +# method on the `Store` ABC. Only a few stores can do them natively (`LocalStore`, +# `MemoryStore`); most (cloud, zip, read-only) cannot. A universal `Store.set_range` +# with a read-modify-write fallback (as in the Rust `zarrs` crate's +# `WritableStorageTraits::set_partial` + `supports_set_partial`) would let every +# store participate, but for our motivating use case — writing one subchunk without +# rewriting the whole shard — that fallback is a footgun: it would silently rewrite +# an entire (possibly multi-GB) shard, defeating the purpose while appearing to +# succeed. An opt-in protocol instead keeps the cost model honest (a store either +# supports cheap ranged writes or doesn't advertise the capability at all) and keeps +# `set_range` out of the signatures of stores that will never support it. +# +# Stores satisfy this protocol *structurally* (by defining the methods), not by +# nominal inheritance, so a subclass can disclaim it by setting the methods to `None` +# (see `GpuMemoryStore`). Any read-modify-write fallback strategy belongs in the +# caller (the sharding codec), which already has to decide between in-place and +# buffer-and-rewrite — mirroring the zarrs layering (storage writes bytes, codec owns +# strategy) without making every store carry the method. If broad-backend partial +# encoding is wanted later, adding a `supports_set_range()` capability flag plus a +# codec-level fallback is an additive change that does not require retrofitting stores. @runtime_checkable class SupportsSetRange(Protocol): """Protocol for stores that support writing to a byte range within an existing value. - Overwrites ``len(value)`` bytes starting at byte offset ``start`` within the - existing stored value for ``key``. The key must already exist and the write - must fit within the existing value (i.e., ``start + len(value) <= len(existing)``). - - Behavior when the write extends past the end of the existing value is - implementation-specific and should not be relied upon. + Overwrites `len(value)` bytes starting at byte offset `start` within the + existing stored value for `key`. The key must already exist and the write + must fit within the existing value (i.e., `start + len(value) <= len(existing)`); + a write that does not fit raises `ValueError`. + + Concurrency and atomicity + ------------------------- + **It is entirely the caller's responsibility to ensure consistency.** Any + coordination needed to keep stored values consistent must be arranged by the + caller. In particular: + + - Concurrent `set_range` calls that write to **disjoint** byte ranges of the + same key are safe. + - Concurrent `set_range` calls that write to **overlapping** ranges of the same + key have order-dependent, unspecified results. The caller must serialize them. + - A `set_range` racing against a `set` or `delete` on the same key is a race + condition, just as concurrent `set` calls are. The caller must serialize these. + - Writes are **not** guaranteed to be atomic with respect to a process crash: + a crash mid-write may leave the value partially updated. The caller is + responsible for any durability or recovery guarantees it requires. + + What an implementation does to honor (or fall short of) this contract — locking, + atomic replacement, and so on — is documented on the implementing store, not here. + The intended consumer (the sharding codec writing inner chunks of deterministic + compressed size) coordinates writes so that they target disjoint ranges. """ async def set_range(self, key: str, value: Buffer, start: int) -> None: ... diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index a677587b28..3c802db9be 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -18,11 +18,11 @@ RangeByteRequest, Store, SuffixByteRequest, - SupportsSetRange, ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import _check_set_range_bounds if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator @@ -96,6 +96,12 @@ def _put_range(path: Path, value: Buffer, start: int) -> None: """Write bytes at a specific offset within an existing file.""" view = value.as_buffer_like() with path.open("r+b") as f: + # Validate bounds before writing: a bare seek+write would silently extend the + # file (zero-filling any gap), but the SupportsSetRange contract requires the + # write to fit within the existing value, so we fail consistently with + # MemoryStore instead. + existing_length = f.seek(0, os.SEEK_END) + _check_set_range_bounds(existing_length, start, len(value)) f.seek(start) f.write(view) @@ -108,7 +114,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) -class LocalStore(Store, SupportsSetRange): +class LocalStore(Store): """ Store for the local file system. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index f6bca4afc9..b05291bf91 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -7,11 +7,12 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store, SupportsSetRange +from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import ( + _check_set_range_bounds, _join_paths, _normalize_byte_range_index, normalize_path, @@ -27,7 +28,7 @@ logger = getLogger(__name__) -class MemoryStore(Store, SupportsSetRange): +class MemoryStore(Store): """ Store for local memory. @@ -202,6 +203,7 @@ async def delete(self, key: str) -> None: def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: buf = self._store_dict[key] target = buf.as_numpy_array() + _check_set_range_bounds(len(target), start, len(value)) if not target.flags.writeable: target = target.copy() self._store_dict[key] = buf.__class__(target) @@ -566,6 +568,19 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None gpu_value = value if isinstance(value, gpu.Buffer) else gpu.Buffer.from_buffer(value) await super().set(key, gpu_value, byte_range=byte_range) + # ``GpuMemoryStore`` deliberately does not support byte-range writes, so it must + # not satisfy ``SupportsSetRange``. The inherited ``MemoryStore`` implementation + # mutates a *host* copy of the GPU buffer (via ``as_numpy_array``) and would + # silently lose the write, and there is no use case for in-place byte-range writes + # into GPU memory (the intended ``set_range`` consumer targets cpu/local storage). + # Disclaiming the inherited methods by setting them to ``None`` makes + # ``isinstance(gpu_store, SupportsSetRange)`` return ``False`` (``runtime_checkable`` + # treats a ``None`` attribute as "method absent"). This works because + # ``MemoryStore`` satisfies the protocol structurally rather than by nominal + # inheritance; mirrors the ``__hash__ = None`` idiom. + set_range = None # type: ignore[assignment] + set_range_sync = None # type: ignore[assignment] + # ----------------------------------------------------------------------------- # ManagedMemoryStore and its registry @@ -860,6 +875,14 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited return await super().set_if_not_exists(_join_paths([self.path, key]), value) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + return await super().set_range(_join_paths([self.path, key]), value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + return super().set_range_sync(_join_paths([self.path, key]), value, start) + async def delete(self, key: str) -> None: # docstring inherited return await super().delete(_join_paths([self.path, key])) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 1f8e9b0a29..ef7bdf296e 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -160,6 +160,23 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> return (start, stop) +def _check_set_range_bounds(existing_length: int, start: int, value_length: int) -> None: + """ + Validate that a ``set_range`` write fits within an existing value. + + Stores implementing ``SupportsSetRange`` use this so the out-of-bounds case fails + the same way everywhere (a clear ``ValueError``) rather than silently extending the + value (as a file seek+write would) or raising an opaque numpy shape error. + """ + if start < 0: + raise ValueError(f"set_range start must be non-negative, got {start}.") + if start + value_length > existing_length: + raise ValueError( + f"set_range write of {value_length} bytes at offset {start} does not fit " + f"within the existing value of length {existing_length}." + ) + + def _join_paths(paths: Iterable[str]) -> str: """ Filter out instances of '' and join the remaining strings with '/'. diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 8c6e37b8f4..0661208102 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -213,6 +213,23 @@ def test_set_range_sync( assert result is not None assert result.to_bytes() == expected + @pytest.mark.parametrize( + ("start", "patch"), + [(9, b"XX"), (10, b"X"), (0, b"ZZZZZZZZZZZ")], + ids=["overhang", "past-end", "too-long"], + ) + async def test_set_range_out_of_bounds( + self, store: LocalStore, start: int, patch: bytes + ) -> None: + """A write that does not fit within the existing value raises, not extends.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + with pytest.raises(ValueError, match="does not fit within the existing value"): + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + # The file is left unchanged (not zero-filled / extended). + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == b"AAAAAAAAAA" + async def test_set_range_not_open(self, store_not_open: LocalStore) -> None: """set_range auto-opens a closed store.""" assert not store_not_open._is_open diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 0d0a007b4e..6a57a54bd3 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -178,6 +178,20 @@ def test_set_range_sync( assert result is not None assert result.to_bytes() == expected + @pytest.mark.parametrize( + ("start", "patch"), + [(9, b"XX"), (10, b"X"), (0, b"ZZZZZZZZZZZ")], + ids=["overhang", "past-end", "too-long"], + ) + async def test_set_range_out_of_bounds( + self, store: MemoryStore, start: int, patch: bytes + ) -> None: + """A write that does not fit within the existing value raises consistently.""" + store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + with pytest.raises(ValueError, match="does not fit within the existing value"): + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + assert store._store_dict["test/key"].to_bytes() == b"AAAAAAAAAA" + async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None: """set_range auto-opens a closed store.""" assert not store_not_open._is_open @@ -296,6 +310,18 @@ def test_from_dict(self) -> None: for v in result._store_dict.values(): assert type(v) is gpu.Buffer + def test_set_range_not_supported(self, store: GpuMemoryStore) -> None: + """GpuMemoryStore deliberately does not satisfy SupportsSetRange. + + Capability detection via isinstance must report False so a consumer (e.g. the + sharding codec) does not select it and crash. The methods are disclaimed + (set to None), so isinstance returns False rather than a false positive. + """ + # mypy statically knows GpuMemoryStore cannot satisfy the protocol (the methods + # are None), which is exactly what we want — but it then flags this runtime + # assertion as unreachable. Keep the runtime check as a regression guard. + assert not isinstance(store, SupportsSetRange) # type: ignore[unreachable] + class TestManagedMemoryStore(StoreTests[ManagedMemoryStore, cpu.Buffer]): store_cls = ManagedMemoryStore @@ -565,6 +591,26 @@ async def test_path_prefix_operations(self) -> None: assert result2 is not None assert result2.to_bytes() == b"value" + def test_supports_set_range(self, store: ManagedMemoryStore) -> None: + assert isinstance(store, SupportsSetRange) + + async def test_set_range_applies_path_prefix(self) -> None: + """set_range must prepend the store's path prefix, matching set/get. + + Regression: an unprefixed inherited set_range would target the wrong key. + """ + store = ManagedMemoryStore(name="set-range-path-test", path="subdir") + await store.set("k", self.buffer_cls.from_bytes(b"AAAAAAAAAA")) + # set() writes to the prefixed backing key. + assert "subdir/k" in store._store_dict + await store.set_range("k", self.buffer_cls.from_bytes(b"XX"), start=2) + store.set_range_sync("k", self.buffer_cls.from_bytes(b"YY"), start=6) + # Both writes landed on the same prefixed value that set/get use. + observed = await store.get("k") + assert observed is not None + assert observed.to_bytes() == b"AAXXAAYYAA" + assert store._store_dict["subdir/k"].to_bytes() == b"AAXXAAYYAA" + async def test_path_list_operations(self) -> None: """Test that list operations filter by path prefix.""" store = ManagedMemoryStore(name="list-test") From a8f1fe712aa0eaf22d4ac744dd74906cf1abbb75 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 3 Jun 2026 21:08:15 +0200 Subject: [PATCH 10/10] fix: share managed-store set_range locks; isolate from_url filesystems MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses two branch-review findings: - ManagedMemoryStore instances sharing a name share their backing dict via the registry but each built its own _key_locks/_key_locks_sync, so concurrent set_range from two handles to the same dict did not actually serialize. Move the per-key lock dicts into the registry, keyed by name, so all handles bound to a name (including with_read_only clones and unpickled handles) share them. Stale entries are pruned when the backing dict is collected. - FsspecStore.from_url claimed ownership of the filesystem (closing it on close()), but fsspec.url_to_fs returns instance-cached filesystems by default, so two from_url calls for the same URL shared one fs — closing one tore down the shared aiohttp session under the other (and any other fsspec consumer). Pass skip_instance_cache=True so each from_url store owns a private fs instance; a user-supplied storage_options can still override it. Adds regression tests: managed locks shared by name, and from_url instances distinct (closing one leaves the other usable). Co-Authored-By: Claude Opus 4.8 (1M context) --- src/zarr/storage/_fsspec.py | 9 ++++++++- src/zarr/storage/_memory.py | 36 +++++++++++++++++++++++++++++---- tests/test_store/test_fsspec.py | 22 ++++++++++++++++++++ tests/test_store/test_memory.py | 16 +++++++++++++++ 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 29201a6fee..0493e2cf74 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -266,7 +266,14 @@ def from_url( from fsspec.core import url_to_fs opts = storage_options or {} - opts = {"asynchronous": True, **opts} + # ``skip_instance_cache=True`` forces a fresh filesystem instance instead of + # an fsspec instance-cached one. Without it, two ``from_url`` calls with the + # same URL/options receive the *same* cached ``AsyncFileSystem``; closing one + # store (which we mark as owning the fs) would tear down the shared aiohttp + # session out from under the other store — and any other fsspec consumer in + # the process. By skipping the cache we own an instance no one else shares, so + # ``close()`` is safe. + opts = {"asynchronous": True, "skip_instance_cache": True, **opts} fs, path = url_to_fs(url, **opts) if not fs.async_impl: diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index b05291bf91..763453de0a 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -614,6 +614,13 @@ def __init__(self) -> None: self._registry: weakref.WeakValueDictionary[str, _ManagedStoreDict] = ( weakref.WeakValueDictionary() ) + # set_range per-key lock dicts, shared by name so every ManagedMemoryStore + # bound to the same backing dict serializes set_range against the others (the + # registry shares the data, so it must share the locks too). Keyed by name + # (_ManagedStoreDict is an unhashable dict subclass) and pruned when the + # corresponding dict has been collected. + self._key_locks: dict[str, dict[str, asyncio.Lock]] = {} + self._key_locks_sync: dict[str, dict[str, threading.Lock]] = {} self._counter = 0 self._lock = threading.Lock() @@ -683,6 +690,24 @@ def get(self, name: str) -> _ManagedStoreDict | None: """ return self._registry.get(name) + def get_key_locks(self, name: str) -> tuple[dict[str, asyncio.Lock], dict[str, threading.Lock]]: + """ + Get the shared set_range per-key lock dicts for a store name. + + All ManagedMemoryStore instances resolving to the same name get the same lock + dicts, so their set_range calls serialize against each other. Created on first + use; stale entries (whose backing dict has been collected) are pruned. + """ + with self._lock: + stale = [n for n in self._key_locks if n not in self._registry] + for n in stale: + del self._key_locks[n] + self._key_locks_sync.pop(n, None) + return ( + self._key_locks.setdefault(name, {}), + self._key_locks_sync.setdefault(name, {}), + ) + _managed_store_dict_registry = _ManagedStoreDictRegistry() @@ -753,8 +778,12 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool = # Get or create a managed dict from the registry self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name) self.path = normalize_path(path) - self._key_locks = {} - self._key_locks_sync = {} + # Share the per-key set_range locks with every other store backed by the same + # dict, so concurrent set_range from different handles to the same name actually + # serialize. + self._key_locks, self._key_locks_sync = _managed_store_dict_registry.get_key_locks( + self._name + ) def __str__(self) -> str: return _join_paths([f"memory://{self._name}", self.path]) @@ -790,8 +819,7 @@ def _from_managed_dict( store._store_dict = managed_dict store._name = name store.path = normalize_path(path) - store._key_locks = {} - store._key_locks_sync = {} + store._key_locks, store._key_locks_sync = _managed_store_dict_registry.get_key_locks(name) return store def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore: diff --git a/tests/test_store/test_fsspec.py b/tests/test_store/test_fsspec.py index 8006470174..fa8426448d 100644 --- a/tests/test_store/test_fsspec.py +++ b/tests/test_store/test_fsspec.py @@ -313,6 +313,28 @@ async def test_from_url_close_releases_store(self) -> None: assert not store._is_open + async def test_from_url_uses_distinct_filesystem_instances(self) -> None: + """Two from_url() calls for the same URL must not share a cached fs. + + Regression: from_url claims ownership and closes the fs on close(); if it used + the fsspec instance cache, two stores would share one fs and closing one would + tear the shared session out from under the other. skip_instance_cache=True must + give each store its own fs. + """ + url = f"s3://{test_bucket_name}/distinct/" + opts = {"endpoint_url": endpoint_url, "anon": False} + store_a = FsspecStore.from_url(url, storage_options=opts) + store_b = FsspecStore.from_url(url, storage_options=opts) + assert store_a.fs is not store_b.fs + # Closing one leaves the other fully usable. + await store_a.set("probe", cpu.Buffer.from_bytes(b"x")) + store_a.close() + await store_b.set("probe", cpu.Buffer.from_bytes(b"y")) + result = await store_b.get("probe", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == b"y" + store_b.close() + def test_direct_construction_does_not_own_filesystem(self) -> None: """Direct FsspecStore() must not claim ownership — the caller owns the fs.""" try: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 6a57a54bd3..c49d2abb36 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -611,6 +611,22 @@ async def test_set_range_applies_path_prefix(self) -> None: assert observed.to_bytes() == b"AAXXAAYYAA" assert store._store_dict["subdir/k"].to_bytes() == b"AAXXAAYYAA" + def test_set_range_locks_shared_by_name(self) -> None: + """Instances sharing a backing dict (same name) share the set_range lock dicts. + + The registry shares the data across same-name handles, so it must share the + locks too — otherwise concurrent set_range from two handles would not serialize. + """ + a = ManagedMemoryStore(name="lock-share-test") + b = ManagedMemoryStore.from_url("memory://lock-share-test") + c = a.with_read_only(not a.read_only) + assert a._key_locks is b._key_locks + assert a._key_locks_sync is b._key_locks_sync + assert a._key_locks is c._key_locks + # A differently named store has independent locks. + other = ManagedMemoryStore(name="lock-share-test-other") + assert other._key_locks is not a._key_locks + async def test_path_list_operations(self) -> None: """Test that list operations filter by path prefix.""" store = ManagedMemoryStore(name="list-test")