diff --git a/changes/4074.bugfix.md b/changes/4074.bugfix.md new file mode 100644 index 0000000000..d55a52b887 --- /dev/null +++ b/changes/4074.bugfix.md @@ -0,0 +1,7 @@ +Fixed several storage and codec bugs: + +- Reading a value with a `SuffixByteRequest` larger than the value now correctly returns the whole value (matching HTTP `bytes=-N` suffix-range semantics), instead of silently returning incorrect data for `MemoryStore`. +- `LoggingStore.get_partial_values` no longer returns empty results when `key_ranges` is passed as a one-shot iterable (e.g. a generator), and `FsspecStore.get_partial_values` now materialises `key_ranges` before checking whether it is empty, so an empty one-shot iterable correctly short-circuits to `[]`. +- `Store.getsize_prefix` no longer over-counts sibling keys that merely share a string prefix (e.g. `getsize_prefix("foo")` no longer includes keys under `foobar/`). +- `ZipStore.close()` no longer raises `AttributeError` when the store was created but never opened (including when used as a context manager without any I/O). +- `codecs_from_list` now raises a descriptive `TypeError` when a `BytesBytesCodec` immediately follows an `ArrayArrayCodec`, instead of a misleading "Required ArrayBytesCodec was not found" `ValueError`. 
diff --git a/pyproject.toml b/pyproject.toml index 9b372192e9..6f6c7265a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -393,7 +393,7 @@ show_error_code_links = true show_error_context = true strict = true warn_unreachable = true -enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool", "truthy-iterable"] [[tool.mypy.overrides]] module = [ diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 304d0cddb5..7c187594df 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -536,6 +536,8 @@ async def getsize_prefix(self, prefix: str) -> int: from zarr.core.common import concurrent_map from zarr.core.config import config + if prefix != "" and not prefix.endswith("/"): + prefix += "/" keys = [(x,) async for x in self.list_prefix(prefix)] limit = config.get("async.concurrency") sizes = await concurrent_map(keys, self.getsize, limit=limit) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 5c26681d6b..032703fc03 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -679,6 +679,7 @@ def codecs_from_list( "must be preceded by either another BytesBytesCodec, or an ArrayBytesCodec. " f"Got {type(prev_codec)} instead." ) + raise TypeError(msg) bytes_bytes += (cur_codec,) else: raise TypeError diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 29201a6fee..89d788af1a 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -424,30 +424,31 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - if key_ranges: - # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. 
- key_ranges = list(key_ranges) - paths: list[str] = [] - starts: list[int | None] = [] - stops: list[int | None] = [] - for key, byte_range in key_ranges: - paths.append(_dereference_path(self.path, key)) - if byte_range is None: - starts.append(None) - stops.append(None) - elif isinstance(byte_range, RangeByteRequest): - starts.append(byte_range.start) - stops.append(byte_range.end) - elif isinstance(byte_range, OffsetByteRequest): - starts.append(byte_range.offset) - stops.append(None) - elif isinstance(byte_range, SuffixByteRequest): - starts.append(-byte_range.suffix) - stops.append(None) - else: - raise ValueError(f"Unexpected byte_range, got {byte_range}.") - else: + # Materialise first: key_ranges may be a one-shot iterable, so a bare + # truthiness check (e.g. `if key_ranges`) would be unreliable for an + # empty generator. _cat_ranges also expects lists of paths/starts/stops. + key_ranges = list(key_ranges) + if not key_ranges: return [] + paths: list[str] = [] + starts: list[int | None] = [] + stops: list[int | None] = [] + for key, byte_range in key_ranges: + paths.append(_dereference_path(self.path, key)) + if byte_range is None: + starts.append(None) + stops.append(None) + elif isinstance(byte_range, RangeByteRequest): + starts.append(byte_range.start) + stops.append(byte_range.end) + elif isinstance(byte_range, OffsetByteRequest): + starts.append(byte_range.offset) + stops.append(None) + elif isinstance(byte_range, SuffixByteRequest): + starts.append(-byte_range.suffix) + stops.append(None) + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}.") # TODO: expectations for exceptions or missing keys? 
res = await self.fs._cat_ranges(paths, starts, stops, on_error="return") # the following is an s3-specific condition we probably don't want to leak diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index 5de300c144..c6f58ccd61 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -179,6 +179,7 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + key_ranges = list(key_ranges) keys = ",".join([k[0] for k in key_ranges]) with self.log(keys): return await self._store.get_partial_values(prototype=prototype, key_ranges=key_ranges) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 1f8e9b0a29..b100f862cf 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -153,7 +153,7 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> start = byte_range.offset stop = len(data) + 1 elif isinstance(byte_range, SuffixByteRequest): - start = len(data) - byte_range.suffix + start = max(0, len(data) - byte_range.suffix) stop = len(data) + 1 else: raise ValueError(f"Unexpected byte_range, got {byte_range}.") diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 897797e999..430b0c3e2a 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -120,6 +120,8 @@ def __setstate__(self, state: dict[str, Any]) -> None: def close(self) -> None: # docstring inherited + if not self._is_open: + return super().close() with self._lock: self._zf.close() diff --git a/src/zarr/testing/stateful.py b/src/zarr/testing/stateful.py index d6c43f4ecc..9817ebd618 100644 --- a/src/zarr/testing/stateful.py +++ b/src/zarr/testing/stateful.py @@ -1,6 +1,6 @@ import builtins import functools -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any, cast import hypothesis.extra.numpy as npst @@ -18,7 +18,12 @@ import zarr from zarr 
import Array -from zarr.abc.store import Store +from zarr.abc.store import ( + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import Buffer, BufferPrototype, cpu, default_buffer_prototype from zarr.core.sync import SyncMixin @@ -460,7 +465,7 @@ def get(self, key: str, prototype: BufferPrototype) -> Buffer | None: return self._sync(self.store.get(key, prototype=prototype)) def get_partial_values( - self, key_ranges: builtins.list[Any], prototype: BufferPrototype + self, key_ranges: Iterable[Any], prototype: BufferPrototype ) -> builtins.list[Buffer | None]: return self._sync(self.store.get_partial_values(prototype=prototype, key_ranges=key_ranges)) @@ -476,6 +481,9 @@ def clear(self) -> None: def exists(self, key: str) -> bool: return self._sync(self.store.exists(key)) + def getsize_prefix(self, prefix: str) -> int: + return self._sync(self.store.getsize_prefix(prefix)) + def list_dir(self, prefix: str) -> None: raise NotImplementedError @@ -555,7 +563,9 @@ def get_partial_values(self, data: DataObject) -> None: key_ranges(keys=st.sampled_from(sorted(self.model.keys())), max_size=MAX_BINARY_SIZE) ) note(f"(get partial) {key_range=}") - obs_maybe = self.store.get_partial_values(key_range, self.prototype) + # Pass a one-shot generator rather than a list: stores (and wrappers such + # as LoggingStore) must not exhaust the iterable before using it. + obs_maybe = self.store.get_partial_values((kr for kr in key_range), self.prototype) observed = [] for obs in obs_maybe: @@ -565,9 +575,23 @@ def get_partial_values(self, data: DataObject) -> None: model_vals_ls = [] for key, byte_range in key_range: - start = byte_range.start - stop = byte_range.end - model_vals_ls.append(self.model[key][start:stop]) + # Independently model each ByteRequest variant (do NOT reuse the + # store's _normalize_byte_range_index helper, so this stays an + # independent oracle). 
Bounds may exceed the value length. + value = self.model[key] + n = len(value) + if byte_range is None: + expected = value[:] + elif isinstance(byte_range, RangeByteRequest): + expected = value[byte_range.start : byte_range.end] + elif isinstance(byte_range, OffsetByteRequest): + expected = value[byte_range.offset :] + elif isinstance(byte_range, SuffixByteRequest): + # "last suffix bytes"; suffix > n means the whole value. + expected = value[max(0, n - byte_range.suffix) :] + else: + raise AssertionError(f"unexpected byte_range {byte_range!r}") + model_vals_ls.append(expected) assert all( obs == exp.to_bytes() for obs, exp in zip(observed, model_vals_ls, strict=True) @@ -612,6 +636,21 @@ def exists(self, key: str) -> None: assert self.store.exists(key) == (key in self.model) + @precondition(lambda self: len(self.model.keys()) > 0) + @rule(data=st.data()) + def getsize_prefix(self, data: DataObject) -> None: + # Measure the size under the first path segment of some existing key. + # getsize_prefix(node) must count only keys under the directory "node/", + # not sibling keys that merely share the string prefix (e.g. measuring + # "a" must not include a sibling key "ab/..."). + key = data.draw(st.sampled_from(sorted(self.model.keys()))) + node = key.split("/")[0] + note(f"(getsize_prefix) {node=}") + + observed = self.store.getsize_prefix(node) + expected = sum(len(value) for k, value in self.model.items() if k.startswith(node + "/")) + assert observed == expected, (observed, expected, node) + @invariant() def check_paths_equal(self) -> None: note("Checking that paths are equal") diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 81024c85c8..11ceeee83a 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -299,10 +299,16 @@ async def test_getsize(self, store: S, key: str, data: bytes) -> None: async def test_getsize_prefix(self, store: S) -> None: """ Test the result of store.getsize_prefix(). 
+ + Includes a sibling key ("cc/0") that shares the string prefix "c" but + belongs to a different directory: getsize_prefix("c") must not count it, + i.e. the prefix is matched as a directory ("c/...") not a raw substring. """ data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") keys = ["c/0/0", "c/0/1", "c/1/0", "c/1/1"] - keys_values = [(k, data_buf) for k in keys] + # Sibling directory sharing the "c" string prefix; must be excluded. + sibling_keys = ["cc/0"] + keys_values = [(k, data_buf) for k in keys + sibling_keys] await store._set_many(keys_values) expected = len(data_buf) * len(keys) observed = await store.getsize_prefix("c") @@ -370,11 +376,19 @@ async def test_get_partial_values( for key, _ in key_ranges: await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) - # read back just part of it + # read back just part of it. Pass key_ranges as a one-shot generator + # (a valid Iterable per the method signature) to ensure stores and + # wrappers do not exhaust the iterable before handing it to the backend. observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges + prototype=default_buffer_prototype(), + key_ranges=(kr for kr in key_ranges), ) + # One result must be returned per requested key range. Checking this + # explicitly guards against a store/wrapper exhausting the key_ranges + # iterable early and silently returning fewer (or no) results. 
+ assert len(observed_maybe) == len(key_ranges) + observed: list[Buffer] = [] expected: list[Buffer] = [] @@ -382,8 +396,7 @@ async def test_get_partial_values( assert obs is not None observed.append(obs) - for idx in range(len(observed)): - key, byte_range = key_ranges[idx] + for key, byte_range in key_ranges: result = await store.get( key, prototype=default_buffer_prototype(), byte_range=byte_range ) diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 7d6556a359..0ef1ba99bb 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -11,7 +11,13 @@ from hypothesis.strategies import SearchStrategy import zarr -from zarr.abc.store import RangeByteRequest, Store +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + Store, + SuffixByteRequest, +) from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.sharding import SUBCHUNK_WRITE_ORDER, ShardingCodec, SubchunkWriteOrder @@ -654,22 +660,29 @@ def predicate(value: tuple[Any, ...]) -> bool: def key_ranges( keys: SearchStrategy[str] = node_names, max_size: int = sys.maxsize -) -> SearchStrategy[list[tuple[str, RangeByteRequest]]]: +) -> SearchStrategy[list[tuple[str, ByteRequest | None]]]: """ Function to generate key_ranges strategy for get_partial_values() returns list strategy w/ form:: - [(key, (range_start, range_end)), - (key, (range_start, range_end)),...] + [(key, byte_request), + (key, byte_request),...] + + where ``byte_request`` is ``None`` or any of the concrete ``ByteRequest`` + subtypes. The bounds are drawn independently of each value's length, so the + offsets/suffixes routinely exceed the data and exercise the clamping logic + in ``_normalize_byte_range_index``. 
""" - def make_request(start: int, length: int) -> RangeByteRequest: + def make_range(start: int, length: int) -> RangeByteRequest: return RangeByteRequest(start, end=min(start + length, max_size)) - byte_ranges = st.builds( - make_request, - start=st.integers(min_value=0, max_value=max_size), - length=st.integers(min_value=0, max_value=max_size), + bound = st.integers(min_value=0, max_value=max_size) + byte_ranges: SearchStrategy[ByteRequest | None] = st.one_of( + st.none(), + st.builds(make_range, start=bound, length=bound), + st.builds(OffsetByteRequest, offset=bound), + st.builds(SuffixByteRequest, suffix=bound), ) key_tuple = st.tuples(keys, byte_ranges) return st.lists(key_tuple, min_size=1, max_size=10) diff --git a/tests/test_codec_pipeline.py b/tests/test_codec_pipeline.py index fa41c2867b..4d596164db 100644 --- a/tests/test_codec_pipeline.py +++ b/tests/test_codec_pipeline.py @@ -1,15 +1,28 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest +pytest.importorskip("hypothesis") + +import hypothesis.strategies as st +from hypothesis import given + import zarr -from zarr.codecs import BytesCodec, CastValue +from zarr.codecs import BytesCodec, CastValue, GzipCodec, TransposeCodec from zarr.core.array import _get_chunk_spec from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.codec_pipeline import codecs_from_list from zarr.core.indexing import BasicIndexer from zarr.storage import MemoryStore +if TYPE_CHECKING: + from collections.abc import Callable + + from zarr.abc.codec import Codec + @pytest.mark.parametrize( ("write_slice", "read_slice", "expected_statuses"), @@ -120,3 +133,65 @@ def test_codec_pipeline_threads_dtype_through_evolve(source_dtype: str, target_d ) arr[:] = np.asarray([0, 1, 2, 3], dtype=source_dtype) np.testing.assert_array_equal(arr[:], np.asarray([0, 1, 2, 3], dtype=source_dtype)) + + +# Property-based check of codecs_from_list ordering validation. 
+# +# Valid codec orderings are exactly: (ArrayArrayCodec)* (ArrayBytesCodec) +# (BytesBytesCodec)*. codecs_from_list walks adjacent pairs and must raise +# TypeError the moment a codec appears in a structurally invalid position -- +# notably, a BytesBytesCodec immediately following an ArrayArrayCodec with no +# ArrayBytesCodec in between (which previously built an error message but never +# raised it, falling through to an unrelated ValueError instead). +_AA = "AA" # ArrayArrayCodec -> TransposeCodec +_AB = "AB" # ArrayBytesCodec -> BytesCodec +_BB = "BB" # BytesBytesCodec -> GzipCodec + +_CODEC_FACTORY: dict[str, Callable[[], Codec]] = { + _AA: lambda: TransposeCodec(order=(0, 1)), + _AB: BytesCodec, + _BB: GzipCodec, +} + + +def _expected_codec_order_outcome(labels: list[str]) -> str: + """Independently predict codecs_from_list's outcome: 'TypeError', + 'ValueError' or 'ok', mirroring its left-to-right scan and the order in + which it checks ordering violations (TypeError) vs. the ArrayBytes-count + constraints (ValueError).""" + prev = None + seen_array_bytes = False + for cur in labels: + if cur == _AA: + if prev in (_AB, _BB): + return "TypeError" + elif cur == _AB: + if prev == _BB: + return "TypeError" + if seen_array_bytes: + return "ValueError" # two ArrayBytesCodecs + seen_array_bytes = True + else: # _BB + if prev == _AA: + return "TypeError" + prev = cur + if not seen_array_bytes: + return "ValueError" # Required ArrayBytesCodec was not found + return "ok" + + +@given(labels=st.lists(st.sampled_from([_AA, _AB, _BB]), min_size=1, max_size=5)) +def test_codecs_from_list_outcome_matches_order_rules(labels: list[str]) -> None: + codecs = [_CODEC_FACTORY[label]() for label in labels] + expected = _expected_codec_order_outcome(labels) + if expected == "TypeError": + with pytest.raises(TypeError): + codecs_from_list(codecs) + elif expected == "ValueError": + with pytest.raises(ValueError): + codecs_from_list(codecs) + else: + # Valid ordering: must classify 
without raising. + aa, _ab, bb = codecs_from_list(codecs) + assert labels.count(_AA) == len(aa) + assert labels.count(_BB) == len(bb) diff --git a/tests/test_store/test_utils.py b/tests/test_store/test_utils.py index b1934e7eae..291526fab8 100644 --- a/tests/test_store/test_utils.py +++ b/tests/test_store/test_utils.py @@ -5,7 +5,9 @@ import pytest -from zarr.storage._utils import ParsedStoreUrl, parse_store_url +from zarr.abc.store import SuffixByteRequest +from zarr.core.buffer.core import default_buffer_prototype +from zarr.storage._utils import ParsedStoreUrl, _normalize_byte_range_index, parse_store_url class TestParseStoreUrl: @@ -95,3 +97,32 @@ def test_drive_letter_not_special_on_non_windows(self, url: str) -> None: result = parse_store_url(url) # urlparse interprets the drive letter as a scheme assert result.scheme == "c" + + +class TestNormalizeByteRangeIndex: + """Tests for _normalize_byte_range_index.""" + + def test_suffix_larger_than_data_returns_all_bytes(self) -> None: + """Regression: SuffixByteRequest with suffix > len(data) must not produce a + negative start index that causes numpy to return fewer bytes than available.""" + prototype = default_buffer_prototype() + data = prototype.buffer.from_bytes(b"hello") # 5 bytes + byte_range = SuffixByteRequest(suffix=7) + start, stop = _normalize_byte_range_index(data, byte_range) + assert start == 0, f"start should be 0 (clamped), got {start}" + result = data[start:stop] + assert len(result) == 5, f"expected all 5 bytes, got {len(result)}" + + def test_suffix_exact_length(self) -> None: + """SuffixByteRequest with suffix == len(data) returns all bytes.""" + prototype = default_buffer_prototype() + data = prototype.buffer.from_bytes(b"hello") + start, _stop = _normalize_byte_range_index(data, SuffixByteRequest(suffix=5)) + assert start == 0 + + def test_suffix_shorter_than_data(self) -> None: + """SuffixByteRequest with suffix < len(data) returns the last n bytes.""" + prototype = 
default_buffer_prototype() + data = prototype.buffer.from_bytes(b"hello") + start, _stop = _normalize_byte_range_index(data, SuffixByteRequest(suffix=3)) + assert start == 2 diff --git a/tests/test_store/test_zip.py b/tests/test_store/test_zip.py index be51bcedcb..ed69114b51 100644 --- a/tests/test_store/test_zip.py +++ b/tests/test_store/test_zip.py @@ -8,11 +8,20 @@ import numpy as np import pytest +from hypothesis import settings +from hypothesis.stateful import ( + RuleBasedStateMachine, + initialize, + precondition, + rule, + run_state_machine_as_test, +) import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu, default_buffer_prototype from zarr.core.group import Group +from zarr.core.sync import sync from zarr.storage import ZipStore from zarr.testing.store import StoreTests @@ -177,3 +186,66 @@ async def test_move(self, tmp_path: Path) -> None: assert destination.exists() assert not origin.exists() assert np.array_equal(array[...], np.arange(10)) + + +class ZipStoreLifecycleMachine(RuleBasedStateMachine): + """Drive a ZipStore through construct / open / write / close transitions. + + Invariant under test: a constructed ZipStore can always be closed without + raising, regardless of whether it was ever opened or did any I/O. This is a + property-based generalization of the former example-based regression tests + for ZipStore.close() being called on a never-opened store (which raised + AttributeError because ``_lock`` is created lazily in ``_sync_open``). + """ + + def __init__(self, tmp_path: Path) -> None: + super().__init__() + self._tmp_path = tmp_path + self._counter = 0 + self.store: ZipStore | None = None + self._opened = False + + @initialize() + def start(self) -> None: + self.store = None + self._opened = False + + @precondition(lambda self: self.store is None) + @rule() + def construct(self) -> None: + # Fresh path each time so mode="w" never clobbers a closed archive. 
+ self._counter += 1 + self.store = ZipStore(self._tmp_path / f"s{self._counter}.zip", mode="w") + self._opened = False + + @precondition(lambda self: self.store is not None and not self._opened) + @rule() + def open(self) -> None: + assert self.store is not None + self.store._sync_open() + self._opened = True + + @precondition(lambda self: self.store is not None and not self._opened) + @rule() + def write(self) -> None: + assert self.store is not None + # store.set auto-opens the store. + sync(self.store.set("a", cpu.Buffer.from_bytes(b"hi"))) + self._opened = True + + @precondition(lambda self: self.store is not None) + @rule() + def close(self) -> None: + assert self.store is not None + # The property under test: close() must never raise, even with no + # prior open or I/O. + self.store.close() + self.store = None + self._opened = False + + +def test_zipstore_close_lifecycle(tmp_path: Path) -> None: + run_state_machine_as_test( # type: ignore[no-untyped-call] + lambda: ZipStoreLifecycleMachine(tmp_path), + settings=settings(max_examples=50, deadline=None), + )