Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions docarray/array/doc_list/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
_dict_to_access_paths,
)
from docarray.utils._internal.compress import _decompress_bytes, _get_compress_ctx
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import import_library, ProtocolType

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -57,9 +57,9 @@

def _protocol_and_compress_from_file_path(
file_path: Union[pathlib.Path, str],
default_protocol: Optional[str] = None,
default_protocol: Optional[ProtocolType] = None,
default_compress: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
) -> Tuple[Optional[ProtocolType], Optional[str]]:
"""Extract protocol and compression algorithm from a string, use defaults if not found.
:param file_path: path of a file.
:param default_protocol: default serialization protocol used in case not found.
Expand All @@ -79,7 +79,7 @@ def _protocol_and_compress_from_file_path(
file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes]
for extension in file_extensions:
if extension in ALLOWED_PROTOCOLS:
protocol = extension
protocol = cast(ProtocolType, extension)
elif extension in ALLOWED_COMPRESSIONS:
compress = extension

Expand Down Expand Up @@ -135,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto':
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
Expand All @@ -157,7 +157,7 @@ def from_bytes(
def _write_bytes(
self,
bf: BinaryIO,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
Expand Down Expand Up @@ -201,7 +201,7 @@ def _write_bytes(

def _to_binary_stream(
self,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _to_binary_stream(

def to_bytes(
self,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
file_ctx: Optional[BinaryIO] = None,
show_progress: bool = False,
Expand Down Expand Up @@ -273,7 +273,7 @@ def to_bytes(
def from_base64(
cls: Type[T],
data: str,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> T:
Expand All @@ -294,7 +294,7 @@ def from_base64(

def to_base64(
self,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> str:
Expand Down Expand Up @@ -383,7 +383,6 @@ def _from_csv_file(
file: Union[StringIO, TextIOWrapper],
dialect: Union[str, csv.Dialect],
) -> 'T':

rows = csv.DictReader(file, dialect=dialect)

doc_type = cls.doc_type
Expand Down Expand Up @@ -576,7 +575,7 @@ def _get_proto_class(cls: Type[T]):
def _load_binary_all(
cls: Type[T],
file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]],
protocol: Optional[str],
protocol: Optional[ProtocolType],
compress: Optional[str],
show_progress: bool,
tensor_type: Optional[Type['AbstractTensor']] = None,
Expand Down Expand Up @@ -659,7 +658,9 @@ def _load_binary_all(
start_pos = end_doc_pos

# variable length bytes doc
load_protocol: str = protocol or 'protobuf'
load_protocol: ProtocolType = protocol or cast(
ProtocolType, 'protobuf'
)
doc = cls.doc_type.from_bytes(
d[start_doc_pos:end_doc_pos],
protocol=load_protocol,
Expand All @@ -680,7 +681,7 @@ def _load_binary_all(
def _load_binary_stream(
cls: Type[T],
file_ctx: ContextManager[io.BufferedReader],
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Generator['T_doc', None, None]:
Expand Down Expand Up @@ -728,7 +729,7 @@ def _load_binary_stream(
len_current_doc_in_bytes = int.from_bytes(
f.read(4), 'big', signed=False
)
load_protocol: str = protocol
load_protocol: ProtocolType = protocol
yield cls.doc_type.from_bytes(
f.read(len_current_doc_in_bytes),
protocol=load_protocol,
Expand All @@ -743,10 +744,12 @@ def _load_binary_stream(
@staticmethod
def _get_file_context(
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str,
protocol: ProtocolType,
compress: Optional[str] = None,
) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
load_protocol: Optional[str] = protocol
) -> Tuple[
Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]
]:
load_protocol: Optional[ProtocolType] = protocol
load_compress: Optional[str] = compress
file_ctx: Union[nullcontext, io.BufferedReader]
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
Expand All @@ -765,7 +768,7 @@ def _get_file_context(
def load_binary(
cls: Type[T],
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
Expand Down Expand Up @@ -814,7 +817,7 @@ def load_binary(
def save_binary(
self,
file: Union[str, pathlib.Path],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions docarray/array/doc_vec/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.pydantic import is_pydantic_v2
from docarray.utils._internal.misc import ProtocolType

if TYPE_CHECKING:
import csv
Expand Down Expand Up @@ -134,7 +135,6 @@ def _from_json_col_dict(
json_columns: Dict[str, Any],
tensor_type: Type[AbstractTensor] = NdArray,
) -> T:

tensor_cols = json_columns['tensor_columns']
doc_cols = json_columns['doc_columns']
docs_vec_cols = json_columns['docs_vec_columns']
Expand Down Expand Up @@ -351,7 +351,7 @@ def from_csv(
def from_base64(
cls: Type[T],
data: str,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
tensor_type: Type['AbstractTensor'] = NdArray,
Expand All @@ -377,7 +377,7 @@ def from_base64(
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
tensor_type: Type['AbstractTensor'] = NdArray,
Expand Down Expand Up @@ -454,7 +454,7 @@ class Person(BaseDoc):
def load_binary(
cls: Type[T],
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
protocol: str = 'protobuf-array',
protocol: ProtocolType = 'protobuf-array',
compress: Optional[str] = None,
show_progress: bool = False,
streaming: bool = False,
Expand Down
11 changes: 4 additions & 7 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
from docarray.utils._internal._typing import safe_issubclass
from docarray.utils._internal.compress import _compress_bytes, _decompress_bytes
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import ProtocolType, import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if TYPE_CHECKING:
Expand All @@ -37,7 +37,6 @@
from docarray.proto import DocProto, NodeProto
from docarray.typing import TensorFlowTensor, TorchTensor


else:
tf = import_library('tensorflow', raise_error=False)
if tf is not None:
Expand Down Expand Up @@ -150,7 +149,7 @@ def __bytes__(self) -> bytes:
return self.to_bytes()

def to_bytes(
self, protocol: str = 'protobuf', compress: Optional[str] = None
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> bytes:
"""Serialize itself into bytes.

Expand All @@ -177,7 +176,7 @@ def to_bytes(
def from_bytes(
cls: Type[T],
data: bytes,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
) -> T:
"""Build Document object from binary bytes
Expand All @@ -203,7 +202,7 @@ def from_bytes(
)

def to_base64(
self, protocol: str = 'protobuf', compress: Optional[str] = None
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
) -> str:
"""Serialize a Document object into as base64 string

Expand Down Expand Up @@ -329,7 +328,6 @@ def _get_content_from_node_proto(
return_field = getattr(value, content_key)

elif content_key in arg_to_container.keys():

if field_name and field_name in cls._docarray_fields():
field_type = cls._get_field_inner_type(field_name)
else:
Expand All @@ -347,7 +345,6 @@ def _get_content_from_node_proto(
deser_dict: Dict[str, Any] = dict()

if field_name and field_name in cls._docarray_fields():

if is_pydantic_v2:
dict_args = get_args(
cls._docarray_fields()[field_name].annotation
Expand Down
9 changes: 5 additions & 4 deletions docarray/store/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich import filesize
from typing_extensions import TYPE_CHECKING, Protocol

from docarray.utils._internal.misc import ProtocolType
from docarray.utils._internal.progress_bar import _get_progressbar

if TYPE_CHECKING:
Expand Down Expand Up @@ -112,12 +113,12 @@ def raise_req_error(resp: 'requests.Response') -> NoReturn:
class Streamable(Protocol):
"""A protocol for streamable objects."""

def to_bytes(self, protocol: str, compress: Optional[str]) -> bytes:
def to_bytes(self, protocol: ProtocolType, compress: Optional[str]) -> bytes:
...

@classmethod
def from_bytes(
cls: Type[T_Elem], bytes: bytes, protocol: str, compress: Optional[str]
cls: Type[T_Elem], bytes: bytes, protocol: ProtocolType, compress: Optional[str]
) -> 'T_Elem':
...

Expand All @@ -133,7 +134,7 @@ def close(self):
def _to_binary_stream(
iterator: Iterator['Streamable'],
total: Optional[int] = None,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator[bytes]:
Expand Down Expand Up @@ -170,7 +171,7 @@ def _from_binary_stream(
cls: Type[T],
stream: ReadableBytes,
total: Optional[int] = None,
protocol: str = 'protobuf',
protocol: ProtocolType = 'protobuf',
compress: Optional[str] = None,
show_progress: bool = False,
) -> Iterator['T']:
Expand Down
6 changes: 5 additions & 1 deletion docarray/utils/_internal/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import types
from typing import Any, Optional
from typing import Any, Optional, Literal

import numpy as np

Expand Down Expand Up @@ -52,6 +52,10 @@
'pymilvus': '"docarray[milvus]"',
}

ProtocolType = Literal[

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed the ALLOWED_PROTOCOLS constant at https://github.com/docarray/docarray/blob/main/docarray/array/doc_list/io.py#L54C1-L54C1

It makes me uncomfortable to define the same information twice, so I can investigate combining the ProtocolType and ALLOWED_PROTOCOLS objects, even though one is a type object and the other is a set of strings.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow-up PR you can define ALLOWED_PROTOCOLS in relation to these types

'protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'
]


def import_library(
package: str, raise_error: bool = True
Expand Down