Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions docarray/array/mixins/io/binary.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import io
import os.path
import pickle
from contextlib import nullcontext
from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional

from ....helper import random_uuid, __windows__, get_compress_ctx, decompress_bytes

if TYPE_CHECKING:
from ....types import T
from ....proto.docarray_pb2 import DocumentArrayProto


class BinaryIOMixin:
Expand All @@ -16,7 +18,7 @@ class BinaryIOMixin:
def load_binary(
cls: Type['T'],
file: Union[str, BinaryIO, bytes],
protocol: Union[str, int] = 'protobuf',
protocol: str = 'pickle-once',
compress: Optional[str] = None,
) -> 'T':
"""Load array elements from a LZ4-compressed binary file.
Expand All @@ -43,28 +45,36 @@ def load_binary(
d = decompress_bytes(d, algorithm=compress)
compress = None

_len = len(random_uuid().bytes)
_binary_delimiter = d[:_len] # first get delimiter
da = cls()
da.extend(
Document.from_bytes(od, protocol=protocol, compress=compress)
for od in d[_len:].split(_binary_delimiter)
)
return da
if protocol == 'protobuf-once':
from ....proto.docarray_pb2 import DocumentArrayProto

dap = DocumentArrayProto()
dap.ParseFromString(d)

return cls.from_protobuf(dap)
elif protocol == 'pickle-once':
return pickle.loads(d)
else:
_len = len(random_uuid().bytes)
_binary_delimiter = d[:_len] # first get delimiter
return cls(
Document.from_bytes(od, protocol=protocol, compress=compress)
for od in d[_len:].split(_binary_delimiter)
)

@classmethod
def from_bytes(
cls: Type['T'],
data: bytes,
protocol: Union[str, int] = 'protobuf',
protocol: str = 'pickle-once',
compress: Optional[str] = None,
) -> 'T':
return cls.load_binary(data, protocol=protocol, compress=compress)

def save_binary(
self,
file: Union[str, BinaryIO],
protocol: Union[str, int] = 'protobuf',
protocol: str = 'pickle-once',
compress: Optional[str] = None,
) -> None:
"""Save array elements into a LZ4 compressed binary file.
Expand All @@ -88,7 +98,7 @@ def save_binary(

def to_bytes(
self,
protocol: Union[str, int] = 'protobuf',
protocol: str = 'pickle-once',
compress: Optional[str] = None,
_file_ctx: Optional[BinaryIO] = None,
) -> bytes:
Expand All @@ -111,11 +121,30 @@ def to_bytes(
fc = f
compress = None
with fc:
for d in self:
f.write(_binary_delimiter)
f.write(d.to_bytes(protocol=protocol, compress=compress))
if protocol == 'protobuf-once':
f.write(self.to_protobuf().SerializePartialToString())
elif protocol == 'pickle-once':
f.write(pickle.dumps(self))
else:
for d in self:
f.write(_binary_delimiter)
f.write(d.to_bytes(protocol=protocol, compress=compress))
if not _file_ctx:
return bf.getvalue()

def to_protobuf(self) -> 'DocumentArrayProto':
    """Build a ``DocumentArrayProto`` message from the elements of this array.

    Every element is converted with its own ``to_protobuf()`` and added, in
    iteration order, to the repeated ``docs`` field of the message.

    :return: the populated protobuf message
    """
    # Function-scope import: the generated protobuf module is only loaded
    # when this method is actually called.
    from ....proto.docarray_pb2 import DocumentArrayProto

    message = DocumentArrayProto()
    message.docs.extend(doc.to_protobuf() for doc in self)
    return message

@classmethod
def from_protobuf(cls: Type['T'], pb_msg: 'DocumentArrayProto') -> 'T':
    """Create an instance of ``cls`` from a ``DocumentArrayProto`` message.

    Each entry of ``pb_msg.docs`` is turned into a ``Document`` via
    ``Document.from_protobuf`` and the resulting lazy stream is handed to
    the ``cls`` constructor.

    :param pb_msg: the protobuf message whose ``docs`` entries are read
    :return: whatever ``cls(...)`` builds from the converted documents
    """
    from .... import Document

    # Kept as a generator so documents are produced only as ``cls`` consumes them.
    documents = (Document.from_protobuf(item) for item in pb_msg.docs)
    return cls(documents)

def __bytes__(self):
    """Support ``bytes(obj)`` by delegating to ``to_bytes()`` with its defaults."""
    serialized = self.to_bytes()
    return serialized
8 changes: 4 additions & 4 deletions docarray/array/mixins/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ParallelMixin:
def apply(
self: 'T',
func: Callable[['Document'], 'Document'],
backend: str = 'process',
backend: str = 'thread',
num_worker: Optional[int] = None,
) -> 'T':
"""Apply each element in itself with ``func``, return itself after modified.
Expand Down Expand Up @@ -53,7 +53,7 @@ def apply(self: 'T', *args, **kwargs) -> 'T':
def map(
self,
func: Callable[['Document'], 'T'],
backend: str = 'process',
backend: str = 'thread',
num_worker: Optional[int] = None,
) -> Generator['T', None, None]:
"""Return an iterator that applies function to every **element** of iterable in parallel, yielding the results.
Expand Down Expand Up @@ -88,7 +88,7 @@ def apply_batch(
self: 'T',
func: Callable[['DocumentArray'], 'DocumentArray'],
batch_size: int,
backend: str = 'process',
backend: str = 'thread',
num_worker: Optional[int] = None,
shuffle: bool = False,
) -> 'T':
Expand Down Expand Up @@ -129,7 +129,7 @@ def map_batch(
self: 'T_DA',
func: Callable[['DocumentArray'], 'T'],
batch_size: int,
backend: str = 'process',
backend: str = 'thread',
num_worker: Optional[int] = None,
shuffle: bool = False,
) -> Generator['T', None, None]:
Expand Down
10 changes: 5 additions & 5 deletions docarray/document/mixins/porting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def to_dict(self):
)

def to_bytes(
self, protocol: Union[str, int] = 'protobuf', compress: Optional[str] = None
self, protocol: str = 'pickle', compress: Optional[str] = None
) -> bytes:
if isinstance(protocol, int):
bstr = pickle.dumps(self, protocol=protocol)
if protocol == 'pickle':
bstr = pickle.dumps(self)
elif protocol == 'protobuf':
bstr = self.to_protobuf().SerializePartialToString()
else:
Expand All @@ -51,11 +51,11 @@ def to_bytes(
def from_bytes(
cls: Type['T'],
data: bytes,
protocol: Union[str, int] = 'protobuf',
protocol: str = 'pickle',
compress: Optional[str] = None,
) -> 'T':
bstr = decompress_bytes(data, algorithm=compress)
if isinstance(protocol, int):
if protocol == 'pickle':
d = pickle.loads(bstr)
elif protocol == 'protobuf':
from ...proto.docarray_pb2 import DocumentProto
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/array/mixins/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def test_from_to_bytes(da_cls):
da.blobs = [[1, 2], [2, 1]]
da[0].tags = {'hello': 'world'}
da2 = da_cls.load_binary(bytes(da))
assert da2.blobs.tolist() == [[1, 2], [2, 1]]
assert da2.embeddings.tolist() == [[1, 2, 3], [4, 5, 6]]
assert da2.blobs == [[1, 2], [2, 1]]
assert da2.embeddings == [[1, 2, 3], [4, 5, 6]]
assert da2[0].tags == {'hello': 'world'}
assert da2[1].tags == {}

Expand Down
13 changes: 11 additions & 2 deletions tests/unit/array/test_from_to_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@


@pytest.mark.parametrize('target_da', [DocumentArray.empty(100), random_docs(100)])
@pytest.mark.parametrize('protocol', ['protobuf', 0, 1, 2, 3, 4])
@pytest.mark.parametrize(
'protocol', ['protobuf', 'protobuf-once', 'pickle', 'pickle-once']
)
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes(target_da, protocol, compress):
bstr = target_da.to_bytes(protocol=protocol, compress=compress)
Expand All @@ -15,7 +17,9 @@ def test_to_from_bytes(target_da, protocol, compress):


@pytest.mark.parametrize('target_da', [DocumentArray.empty(100), random_docs(100)])
@pytest.mark.parametrize('protocol', ['protobuf', 0, 1, 2, 3, 4])
@pytest.mark.parametrize(
'protocol', ['protobuf', 'protobuf-once', 'pickle', 'pickle-once']
)
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_save_bytes(target_da, protocol, compress, tmpfile):
target_da.save_binary(tmpfile, protocol=protocol, compress=compress)
Expand All @@ -28,3 +32,8 @@ def test_save_bytes(target_da, protocol, compress, tmpfile):
DocumentArray.load_binary(str(tmpfile), protocol=protocol, compress=compress)
with open(tmpfile, 'rb') as fp:
DocumentArray.load_binary(fp, protocol=protocol, compress=compress)


@pytest.mark.parametrize('target_da', [DocumentArray.empty(100), random_docs(100)])
def test_from_to_protobuf(target_da):
    """Round-trip an array through its protobuf message; passes if nothing raises."""
    proto_msg = target_da.to_protobuf()
    DocumentArray.from_protobuf(proto_msg)
2 changes: 1 addition & 1 deletion tests/unit/document/test_porting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tests import random_docs


@pytest.mark.parametrize('protocol', ['protobuf', 0, 1, 2, 3, 4])
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes(protocol, compress):
d = Document(embedding=[1, 2, 3, 4, 5], text='hello')
Expand Down