Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changes/4074.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Fixed several storage and codec bugs:

- Reading a value with a `SuffixByteRequest` larger than the value now correctly returns the whole value (matching HTTP `bytes=-N` suffix-range semantics), instead of silently returning incorrect data for `MemoryStore`.
- `LoggingStore.get_partial_values` and `FsspecStore.get_partial_values` no longer return empty results when `key_ranges` is passed as a one-shot iterable (e.g. a generator).
- `Store.getsize_prefix` no longer over-counts sibling keys that merely share a string prefix (e.g. `getsize_prefix("foo")` no longer includes keys under `foobar/`).
- `ZipStore.close()` no longer raises `AttributeError` when the store was created but never opened (including when used as a context manager without any I/O).
- `codecs_from_list` now raises a descriptive `TypeError` when a `BytesBytesCodec` immediately follows an `ArrayArrayCodec`, instead of a misleading "Required ArrayBytesCodec was not found" `ValueError`.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ show_error_code_links = true
show_error_context = true
strict = true
warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool", "truthy-iterable"]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flagging this change to our mypy config. The `FsspecStore` code was relying on a truthiness check of an iterable (`if key_ranges: ...`), which I think we should avoid, so enabling `truthy-iterable` lets mypy catch this pattern for us.


[[tool.mypy.overrides]]
module = [
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ async def getsize_prefix(self, prefix: str) -> int:
from zarr.core.common import concurrent_map
from zarr.core.config import config

if prefix != "" and not prefix.endswith("/"):
prefix += "/"
keys = [(x,) async for x in self.list_prefix(prefix)]
limit = config.get("async.concurrency")
sizes = await concurrent_map(keys, self.getsize, limit=limit)
Expand Down
1 change: 1 addition & 0 deletions src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def codecs_from_list(
"must be preceded by either another BytesBytesCodec, or an ArrayBytesCodec. "
f"Got {type(prev_codec)} instead."
)
raise TypeError(msg)
bytes_bytes += (cur_codec,)
else:
raise TypeError
Expand Down
47 changes: 24 additions & 23 deletions src/zarr/storage/_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,30 +424,31 @@ async def get_partial_values(
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
# docstring inherited
if key_ranges:
# _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest.
key_ranges = list(key_ranges)
paths: list[str] = []
starts: list[int | None] = []
stops: list[int | None] = []
for key, byte_range in key_ranges:
paths.append(_dereference_path(self.path, key))
if byte_range is None:
starts.append(None)
stops.append(None)
elif isinstance(byte_range, RangeByteRequest):
starts.append(byte_range.start)
stops.append(byte_range.end)
elif isinstance(byte_range, OffsetByteRequest):
starts.append(byte_range.offset)
stops.append(None)
elif isinstance(byte_range, SuffixByteRequest):
starts.append(-byte_range.suffix)
stops.append(None)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
else:
# Materialise first: key_ranges may be a one-shot iterable, so a bare
# truthiness check (e.g. `if key_ranges`) would be unreliable for an
# empty generator. _cat_ranges also expects lists of paths/starts/stops.
key_ranges = list(key_ranges)
if not key_ranges:
return []
paths: list[str] = []
starts: list[int | None] = []
stops: list[int | None] = []
for key, byte_range in key_ranges:
paths.append(_dereference_path(self.path, key))
if byte_range is None:
starts.append(None)
stops.append(None)
elif isinstance(byte_range, RangeByteRequest):
starts.append(byte_range.start)
stops.append(byte_range.end)
elif isinstance(byte_range, OffsetByteRequest):
starts.append(byte_range.offset)
stops.append(None)
elif isinstance(byte_range, SuffixByteRequest):
starts.append(-byte_range.suffix)
stops.append(None)
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
# TODO: expectations for exceptions or missing keys?
res = await self.fs._cat_ranges(paths, starts, stops, on_error="return")
# the following is an s3-specific condition we probably don't want to leak
Expand Down
1 change: 1 addition & 0 deletions src/zarr/storage/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def get_partial_values(
key_ranges: Iterable[tuple[str, ByteRequest | None]],
) -> list[Buffer | None]:
# docstring inherited
key_ranges = list(key_ranges)
keys = ",".join([k[0] for k in key_ranges])
with self.log(keys):
return await self._store.get_partial_values(prototype=prototype, key_ranges=key_ranges)
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/storage/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) ->
start = byte_range.offset
stop = len(data) + 1
elif isinstance(byte_range, SuffixByteRequest):
start = len(data) - byte_range.suffix
start = max(0, len(data) - byte_range.suffix)
stop = len(data) + 1
else:
raise ValueError(f"Unexpected byte_range, got {byte_range}.")
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/storage/_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __setstate__(self, state: dict[str, Any]) -> None:

def close(self) -> None:
# docstring inherited
if not self._is_open:
return
super().close()
with self._lock:
self._zf.close()
Expand Down
53 changes: 46 additions & 7 deletions src/zarr/testing/stateful.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import builtins
import functools
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any, cast

import hypothesis.extra.numpy as npst
Expand All @@ -18,7 +18,12 @@

import zarr
from zarr import Array
from zarr.abc.store import Store
from zarr.abc.store import (
OffsetByteRequest,
RangeByteRequest,
Store,
SuffixByteRequest,
)
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import Buffer, BufferPrototype, cpu, default_buffer_prototype
from zarr.core.sync import SyncMixin
Expand Down Expand Up @@ -460,7 +465,7 @@ def get(self, key: str, prototype: BufferPrototype) -> Buffer | None:
return self._sync(self.store.get(key, prototype=prototype))

def get_partial_values(
self, key_ranges: builtins.list[Any], prototype: BufferPrototype
self, key_ranges: Iterable[Any], prototype: BufferPrototype
) -> builtins.list[Buffer | None]:
return self._sync(self.store.get_partial_values(prototype=prototype, key_ranges=key_ranges))

Expand All @@ -476,6 +481,9 @@ def clear(self) -> None:
def exists(self, key: str) -> bool:
return self._sync(self.store.exists(key))

def getsize_prefix(self, prefix: str) -> int:
return self._sync(self.store.getsize_prefix(prefix))

def list_dir(self, prefix: str) -> None:
raise NotImplementedError

Expand Down Expand Up @@ -555,7 +563,9 @@ def get_partial_values(self, data: DataObject) -> None:
key_ranges(keys=st.sampled_from(sorted(self.model.keys())), max_size=MAX_BINARY_SIZE)
)
note(f"(get partial) {key_range=}")
obs_maybe = self.store.get_partial_values(key_range, self.prototype)
# Pass a one-shot generator rather than a list: stores (and wrappers such
# as LoggingStore) must not exhaust the iterable before using it.
obs_maybe = self.store.get_partial_values((kr for kr in key_range), self.prototype)
observed = []

for obs in obs_maybe:
Expand All @@ -565,9 +575,23 @@ def get_partial_values(self, data: DataObject) -> None:
model_vals_ls = []

for key, byte_range in key_range:
start = byte_range.start
stop = byte_range.end
model_vals_ls.append(self.model[key][start:stop])
# Independently model each ByteRequest variant (do NOT reuse the
# store's _normalize_byte_range_index helper, so this stays an
# independent oracle). Bounds may exceed the value length.
value = self.model[key]
n = len(value)
if byte_range is None:
expected = value[:]
elif isinstance(byte_range, RangeByteRequest):
expected = value[byte_range.start : byte_range.end]
elif isinstance(byte_range, OffsetByteRequest):
expected = value[byte_range.offset :]
elif isinstance(byte_range, SuffixByteRequest):
# "last suffix bytes"; suffix > n means the whole value.
expected = value[max(0, n - byte_range.suffix) :]
else:
raise AssertionError(f"unexpected byte_range {byte_range!r}")
model_vals_ls.append(expected)

assert all(
obs == exp.to_bytes() for obs, exp in zip(observed, model_vals_ls, strict=True)
Expand Down Expand Up @@ -612,6 +636,21 @@ def exists(self, key: str) -> None:

assert self.store.exists(key) == (key in self.model)

@precondition(lambda self: len(self.model.keys()) > 0)
@rule(data=st.data())
def getsize_prefix(self, data: DataObject) -> None:
# Measure the size under the first path segment of some existing key.
# getsize_prefix(node) must count only keys under the directory "node/",
# not sibling keys that merely share the string prefix (e.g. measuring
# "a" must not include a sibling key "ab/...").
key = data.draw(st.sampled_from(sorted(self.model.keys())))
node = key.split("/")[0]
note(f"(getsize_prefix) {node=}")

observed = self.store.getsize_prefix(node)
expected = sum(len(value) for k, value in self.model.items() if k.startswith(node + "/"))
assert observed == expected, (observed, expected, node)

@invariant()
def check_paths_equal(self) -> None:
note("Checking that paths are equal")
Expand Down
23 changes: 18 additions & 5 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,16 @@ async def test_getsize(self, store: S, key: str, data: bytes) -> None:
async def test_getsize_prefix(self, store: S) -> None:
"""
Test the result of store.getsize_prefix().

Includes a sibling key ("cc/0") that shares the string prefix "c" but
belongs to a different directory: getsize_prefix("c") must not count it,
i.e. the prefix is matched as a directory ("c/...") not a raw substring.
"""
data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
keys = ["c/0/0", "c/0/1", "c/1/0", "c/1/1"]
keys_values = [(k, data_buf) for k in keys]
# Sibling directory sharing the "c" string prefix; must be excluded.
sibling_keys = ["cc/0"]
keys_values = [(k, data_buf) for k in keys + sibling_keys]
await store._set_many(keys_values)
expected = len(data_buf) * len(keys)
observed = await store.getsize_prefix("c")
Expand Down Expand Up @@ -370,20 +376,27 @@ async def test_get_partial_values(
for key, _ in key_ranges:
await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))

# read back just part of it
# read back just part of it. Pass key_ranges as a one-shot generator
# (a valid Iterable per the method signature) to ensure stores and
# wrappers do not exhaust the iterable before handing it to the backend.
observed_maybe = await store.get_partial_values(
prototype=default_buffer_prototype(), key_ranges=key_ranges
prototype=default_buffer_prototype(),
key_ranges=(kr for kr in key_ranges),
)

# One result must be returned per requested key range. Checking this
# explicitly guards against a store/wrapper exhausting the key_ranges
# iterable early and silently returning fewer (or no) results.
assert len(observed_maybe) == len(key_ranges)

observed: list[Buffer] = []
expected: list[Buffer] = []

for obs in observed_maybe:
assert obs is not None
observed.append(obs)

for idx in range(len(observed)):
key, byte_range = key_ranges[idx]
for key, byte_range in key_ranges:
result = await store.get(
key, prototype=default_buffer_prototype(), byte_range=byte_range
)
Expand Down
31 changes: 22 additions & 9 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from hypothesis.strategies import SearchStrategy

import zarr
from zarr.abc.store import RangeByteRequest, Store
from zarr.abc.store import (
ByteRequest,
OffsetByteRequest,
RangeByteRequest,
Store,
SuffixByteRequest,
)
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.crc32c_ import Crc32cCodec
from zarr.codecs.sharding import SUBCHUNK_WRITE_ORDER, ShardingCodec, SubchunkWriteOrder
Expand Down Expand Up @@ -654,22 +660,29 @@ def predicate(value: tuple[Any, ...]) -> bool:

def key_ranges(
keys: SearchStrategy[str] = node_names, max_size: int = sys.maxsize
) -> SearchStrategy[list[tuple[str, RangeByteRequest]]]:
) -> SearchStrategy[list[tuple[str, ByteRequest | None]]]:
"""
Function to generate key_ranges strategy for get_partial_values()
returns list strategy w/ form::

[(key, (range_start, range_end)),
(key, (range_start, range_end)),...]
[(key, byte_request),
(key, byte_request),...]

where ``byte_request`` is ``None`` or any of the concrete ``ByteRequest``
subtypes. The bounds are drawn independently of each value's length, so the
offsets/suffixes routinely exceed the data and exercise the clamping logic
in ``_normalize_byte_range_index``.
"""

def make_request(start: int, length: int) -> RangeByteRequest:
def make_range(start: int, length: int) -> RangeByteRequest:
return RangeByteRequest(start, end=min(start + length, max_size))

byte_ranges = st.builds(
make_request,
start=st.integers(min_value=0, max_value=max_size),
length=st.integers(min_value=0, max_value=max_size),
bound = st.integers(min_value=0, max_value=max_size)
byte_ranges: SearchStrategy[ByteRequest | None] = st.one_of(
st.none(),
st.builds(make_range, start=bound, length=bound),
st.builds(OffsetByteRequest, offset=bound),
st.builds(SuffixByteRequest, suffix=bound),
)
key_tuple = st.tuples(keys, byte_ranges)
return st.lists(key_tuple, min_size=1, max_size=10)
Expand Down
Loading
Loading