diff --git a/changes/3638.feature.md b/changes/3638.feature.md new file mode 100644 index 0000000000..ad2276fd51 --- /dev/null +++ b/changes/3638.feature.md @@ -0,0 +1 @@ +Add methods for reading stored objects as bytes and JSON-decoded bytes to store classes. \ No newline at end of file diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..0e98777ff5 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,11 +1,14 @@ from __future__ import annotations +import asyncio +import json from abc import ABC, abstractmethod -from asyncio import gather from dataclasses import dataclass from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.sync import sync + if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType @@ -206,6 +209,211 @@ async def get( """ ... + async def get_bytes( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. Use this when you need the raw byte content of a stored value. + + Parameters + ---------- + key : str + The key identifying the data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_bytes : Synchronous version of this method. + get_json : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = await store.get_bytes("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + buffer = await self.get(key, prototype, byte_range) + if buffer is None: + raise FileNotFoundError(key) + return buffer.to_bytes() + + def get_bytes_sync( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> bytes: + """ + Retrieve raw bytes from the store synchronously. + + This is a synchronous wrapper around ``get_bytes()``. It should only + be called from non-async code. For async contexts, use ``get_bytes()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_bytes : Asynchronous version of this method. + get_json_sync : Synchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = MemoryStore() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes_sync("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + + return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) + + async def get_json( + self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. + + Parameters + ---------- + key : str + The key identifying the JSON data to retrieve. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes : Method for retrieving raw bytes. + get_json_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) + + def get_json_sync( + self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None + ) -> Any: + """ + Retrieve and parse JSON data from the store synchronously. + + This is a synchronous wrapper around ``get_json()``. It should only + be called from non-async code. For async contexts, use ``get_json()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype + The buffer prototype to use for reading the data. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_json : Asynchronous version of this method. + get_bytes_sync : Synchronous method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = store.get_json_sync("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return sync(self.get_json(key, prototype=prototype, byte_range=byte_range)) + @abstractmethod async def get_partial_values( self, @@ -278,7 +486,7 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ Insert multiple (key, value) pairs into storage. """ - await gather(*starmap(self.set, values)) + await asyncio.gather(*starmap(self.set, values)) @property def supports_consolidated_metadata(self) -> bool: diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..9fb3f8b6ad 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -8,7 +8,7 @@ import sys import uuid from pathlib import Path -from typing import TYPE_CHECKING, BinaryIO, Literal, Self +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( ByteRequest, @@ -306,6 +306,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass + async def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + def get_bytes_sync( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead. + + See Also + -------- + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. + + Examples + -------- + >>> store = LocalStore("data") + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) + + async def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json(key, prototype=prototype, byte_range=byte_range) + + def get_json_sync( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead. + + See Also + -------- + Store.get_json_sync : Base implementation with full documentation. + get_json : Asynchronous version of this method. + get_bytes_sync : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = LocalStore("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) + async def move(self, dest_root: Path | str) -> None: """ Move the store to another path. The old root directory is deleted. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 904be922d7..1568cc6736 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,7 +1,7 @@ from __future__ import annotations from logging import getLogger -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu @@ -175,6 +175,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key + async def get_bytes( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + def get_bytes_sync( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead. + + See Also + -------- + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) + + async def get_json( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return await super().get_json(key, prototype=prototype, byte_range=byte_range) + + def get_json_sync( + self, + key: str = "", + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead. + + See Also + -------- + Store.get_json_sync : Base implementation with full documentation. + get_json : Asynchronous version of this method. + get_bytes_sync : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = default_buffer_prototype() + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) + class GpuMemoryStore(MemoryStore): """ diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ad3b80da41..a56061ae12 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import pickle from abc import abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar @@ -23,7 +24,7 @@ SuffixByteRequest, ) from zarr.core.buffer import Buffer, default_buffer_prototype -from zarr.core.sync import _collect_aiterator +from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -526,6 +527,46 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new + async def test_get_bytes(self, store: S) -> None: + """ + Test that the get_bytes method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + assert await store.get_bytes(key, prototype=default_buffer_prototype()) == data + with pytest.raises(FileNotFoundError): + await store.get_bytes("nonexistent_key", prototype=default_buffer_prototype()) + + def test_get_bytes_sync(self, store: S) -> None: + """ + Test that the get_bytes_sync method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + assert store.get_bytes_sync(key, prototype=default_buffer_prototype()) == data + + async def test_get_json(self, store: S) -> None: + """ + Test that the get_json method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) + assert await store.get_json(key, prototype=default_buffer_prototype()) == data + + def test_get_json_sync(self, store: S) -> None: + """ + Test that the get_json method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) + assert store.get_json_sync(key, prototype=default_buffer_prototype()) == data + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..fa4bc7cfc0 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,7 +1,9 @@ from __future__ import annotations +import json import pathlib import re +from typing import TYPE_CHECKING import numpy as np import pytest @@ -9,11 +11,15 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu +from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests from zarr.testing.utils import assert_bytes_equal +if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype + class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]): store_cls = LocalStore @@ -108,6 +114,54 @@ async def test_move( ): await store2.move(destination) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..96b7fe9845 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Any @@ -9,12 +10,14 @@ import zarr from zarr.core.buffer import Buffer, cpu, gpu +from zarr.core.sync import sync from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype from zarr.core.common import ZarrFormat @@ -76,6 +79,54 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")