diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index fdad272b94c..45258b2c69d 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -46,8 +46,8 @@ T = TypeVar('T', bound='IOMixinArray') T_doc = TypeVar('T_doc', bound=BaseDoc) -ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'} -SINGLE_PROTOCOLS = {'pickle', 'protobuf'} +ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'} +SINGLE_PROTOCOLS = {'pickle', 'protobuf', 'json'} ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS) ALLOWED_COMPRESSIONS = {'lz4', 'bz2', 'lzma', 'zlib', 'gzip'} @@ -180,6 +180,8 @@ def _write_bytes( f.write(self.to_protobuf().SerializePartialToString()) elif protocol == 'pickle-array': f.write(pickle.dumps(self)) + elif protocol == 'json-array': + f.write(self.to_json()) elif protocol in SINGLE_PROTOCOLS: f.write( b''.join( @@ -575,7 +577,11 @@ def _load_binary_all( else: d = fp.read() - if protocol is not None and protocol in ('pickle-array', 'protobuf-array'): + if protocol is not None and protocol in ( + 'pickle-array', + 'protobuf-array', + 'json-array', + ): if _get_compress_ctx(algorithm=compress) is not None: d = _decompress_bytes(d, algorithm=compress) compress = None @@ -590,6 +596,9 @@ def _load_binary_all( elif protocol is not None and protocol == 'pickle-array': return pickle.loads(d) + elif protocol is not None and protocol == 'json-array': + return cls.from_json(d) + # Binary format for streaming case else: from rich import filesize diff --git a/tests/units/array/test_array_save_load.py b/tests/units/array/test_array_save_load.py index 1a632673d15..a56ad13064a 100644 --- a/tests/units/array/test_array_save_load.py +++ b/tests/units/array/test_array_save_load.py @@ -16,7 +16,7 @@ class MyDoc(BaseDoc): @pytest.mark.slow @pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] ) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True]) @@ -52,7 +52,7 @@ def test_array_save_load_binary(protocol, compress, tmp_path, show_progress): @pytest.mark.slow @pytest.mark.parametrize( - 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle'] + 'protocol', ['pickle-array', 'protobuf-array', 'protobuf', 'pickle', 'json-array'] ) @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None]) @pytest.mark.parametrize('show_progress', [False, True])