Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import numpy as np

from docarray.base_doc.base_node import BaseNode
from docarray.base_doc import BaseDoc
from docarray.display.document_array_summary import DocArraySummary
from docarray.typing.abstract_type import AbstractType
from docarray.utils._internal._typing import change_cls_name

if TYPE_CHECKING:
Expand All @@ -33,8 +33,8 @@
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]


class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
document_type: Type[BaseDoc]
class AnyDocArray(Sequence[T_doc], Generic[T_doc], BaseNode):
_document_type: Type[BaseDoc]
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDoc], Type]] = {}

def __repr__(self):
Expand All @@ -58,9 +58,9 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
global _DocArrayTyped

class _DocArrayTyped(cls): # type: ignore
document_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
_document_type: Type[BaseDoc] = cast(Type[BaseDoc], item)

for field in _DocArrayTyped.document_type.__fields__.keys():
for field in _DocArrayTyped._document_type.__fields__.keys():

def _property_generator(val: str):
def _getter(self):
Expand Down
97 changes: 47 additions & 50 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
List,
MutableSequence,
Optional,
Sequence,
Type,
TypeVar,
Union,
overload,
Dict,
)

import orjson
from typing_inspect import is_union_type

from docarray.array.abstract_array import AnyDocArray
Expand All @@ -25,16 +26,15 @@
IndexIterType,
)
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.base_doc.io.json import orjson_dumps, orjson_dumps_and_decode
from docarray.typing import NdArray
from pydantic import BaseModel
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.array.stacked.array_stacked import DocArrayStacked
from docarray.proto import DocumentArrayProto
from docarray.typing import TorchTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='DocArray')
T_doc = TypeVar('T_doc', bound=BaseDoc)
Expand All @@ -52,13 +52,17 @@ def _delegate_meth_to_data(meth_name: str) -> Callable:

@wraps(func)
def _delegate_meth(self, *args, **kwargs):
return getattr(self._data, meth_name)(*args, **kwargs)
return getattr(self.data, meth_name)(*args, **kwargs)

return _delegate_meth


class DocArray(
IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
BaseModel,
PushPullMixin,
IndexingSequenceMixin[T_doc],
IOMixinArray,
AnyDocArray[T_doc],
):
"""
DocArray is a container of Documents.
Expand Down Expand Up @@ -121,28 +125,22 @@ class Image(BaseDoc):

"""

document_type: Type[BaseDoc] = AnyDoc
data: List[T_doc] = []
_document_type: Type[BaseDoc] = AnyDoc

class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps_and_decode
json_encoders = {AbstractTensor: orjson_dumps}

validate_assignment = True

def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
data: Optional[Iterable[T_doc]] = None,
):
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []

@classmethod
def construct(
cls: Type[T],
docs: Sequence[T_doc],
) -> T:
"""
Create a DocArray without validation any data. The data must come from a
trusted source
:param docs: a Sequence (list) of Document with the same schema
:return:
"""
da = cls.__new__(cls)
da._data = docs if isinstance(docs, list) else list(docs)
return da
super().__init__()
self.data: List[T_doc] = list(self._validate_docs(data)) if data else []

def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
"""
Expand All @@ -153,17 +151,19 @@ def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:

def _validate_one_doc(self, doc: T_doc) -> T_doc:
"""Validate if a Document is compatible with this DocArray"""
if not issubclass(self.document_type, AnyDoc) and not isinstance(
doc, self.document_type
if isinstance(doc, Dict):
return self._document_type(**doc)
elif not issubclass(self._document_type, AnyDoc) and not isinstance(
doc, self._document_type
):
raise ValueError(f'{doc} is not a {self.document_type}')
raise ValueError(f'{doc} is not a {self._document_type}')
return doc

def __len__(self):
return len(self._data)
return len(self.data)

def __iter__(self):
return iter(self._data)
return iter(self.data)

def __bytes__(self) -> bytes:
with io.BytesIO() as bf:
Expand All @@ -176,7 +176,7 @@ def append(self, doc: T_doc):
as the document_type of this DocArray otherwise it will fail.
:param doc: A Document
"""
self._data.append(self._validate_one_doc(doc))
self.data.append(self._validate_one_doc(doc))

def extend(self, docs: Iterable[T_doc]):
"""
Expand All @@ -185,7 +185,7 @@ def extend(self, docs: Iterable[T_doc]):
fail.
:param docs: Iterable of Documents
"""
self._data.extend(self._validate_docs(docs))
self.data.extend(self._validate_docs(docs))

def insert(self, i: int, doc: T_doc):
"""
Expand All @@ -194,7 +194,7 @@ class as the document_type of this DocArray otherwise it will fail.
:param i: index to insert
:param doc: A Document
"""
self._data.insert(i, self._validate_one_doc(doc))
self.data.insert(i, self._validate_one_doc(doc))

pop = _delegate_meth_to_data('pop')
remove = _delegate_meth_to_data('remove')
Expand All @@ -211,7 +211,7 @@ def _get_data_column(
:return: Returns a list of the field value for each document
in the array like container
"""
field_type = self.__class__.document_type._get_field_type(field)
field_type = self.__class__._document_type._get_field_type(field)

if (
not is_union_type(field_type)
Expand Down Expand Up @@ -255,25 +255,22 @@ def stack(
"""
from docarray.array.stacked.array_stacked import DocArrayStacked

return DocArrayStacked.__class_getitem__(self.document_type)(
return DocArrayStacked.__class_getitem__(self._document_type)(
self, tensor_type=tensor_type
)

@classmethod
def validate(
cls: Type[T],
value: Union[T, Iterable[BaseDoc]],
field: 'ModelField',
config: 'BaseConfig',
):
from docarray.array.stacked.array_stacked import DocArrayStacked

if isinstance(value, (cls, DocArrayStacked)):
return value
elif isinstance(value, Iterable):
return cls(value)
else:
raise TypeError(f'Expecting an Iterable of {cls.document_type}')
# @classmethod
# def validate(cls, value: Any) -> 'DocArray[T_doc]':
# from docarray.array.stacked.array_stacked import DocArrayStacked
#
# if isinstance(value, (cls, DocArrayStacked)):
# return value
# elif isinstance(value, Dict):
# return cls([cls._document_type(**v) for v in value['data']])
# elif isinstance(value, Iterable):
# return cls(value)
# else:
# raise TypeError(f'Expecting an Iterable of {cls._document_type}')

def traverse_flat(
self: 'DocArray',
Expand Down
37 changes: 18 additions & 19 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,20 @@ def _protocol_and_compress_from_file_path(

class _LazyRequestReader:
def __init__(self, r):
self._data = r.iter_content(chunk_size=1024 * 1024)
self.data = r.iter_content(chunk_size=1024 * 1024)
self.content = b''

def __getitem__(self, item: slice):
while len(self.content) < item.stop:
try:
self.content += next(self._data)
self.content += next(self.data)
except StopIteration:
return self.content[item.start : -1 : item.step]
return self.content[item]


class IOMixinArray(Iterable[BaseDoc]):

document_type: Type[BaseDoc]
_document_type: Type[BaseDoc]

@abstractmethod
def __len__(self):
Expand All @@ -113,7 +112,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T:
:param pb_msg: The protobuf message from where to construct the DocArray
"""
return cls(
cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs
cls._document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs
)

def to_protobuf(self) -> 'DocumentArrayProto':
Expand Down Expand Up @@ -250,7 +249,7 @@ def to_bytes(
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""

with (file_ctx or io.BytesIO()) as bf:
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
protocol=protocol,
Expand Down Expand Up @@ -318,7 +317,7 @@ def from_json(
:return: the deserialized DocArray
"""
json_docs = json.loads(file)
return cls([cls.document_type.parse_raw(v) for v in json_docs])
return cls([cls._document_type(**v) for v in json_docs['data']])

def to_json(self) -> str:
"""Convert the object into a JSON string. Can be loaded via :meth:`.from_json`.
Expand All @@ -335,7 +334,7 @@ def from_csv(
) -> 'DocArray':
"""
Load a DocArray from a csv file following the schema defined in the
:attr:`~docarray.DocArray.document_type` attribute.
:attr:`~docarray.DocArray._document_type` attribute.
Every row of the csv file will be mapped to one document in the array.
The column names (defined in the first row) have to match the field names
of the Document type.
Expand All @@ -354,13 +353,13 @@ def from_csv(
"""
from docarray import DocArray

if cls.document_type == AnyDoc:
if cls._document_type == AnyDoc:
raise TypeError(
'There is no document schema defined. '
'Please specify the DocArray\'s Document type using `DocArray[MyDoc]`.'
)

doc_type = cls.document_type
doc_type = cls._document_type
da = DocArray.__class_getitem__(doc_type)()

with open(file_path, 'r', encoding=encoding) as fp:
Expand All @@ -377,7 +376,7 @@ def from_csv(
if not all(valid_paths):
raise ValueError(
f'Column names do not match the schema of the DocArray\'s '
f'document type ({cls.document_type.__name__}): '
f'document type ({cls._document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

Expand Down Expand Up @@ -406,7 +405,7 @@ def to_csv(
'excel-tab' (for tab separated values),
'unix' (for csv file generated on UNIX systems).
"""
fields = self.document_type._get_access_paths()
fields = self._document_type._get_access_paths()

with open(file_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=fields, dialect=dialect)
Expand All @@ -420,7 +419,7 @@ def to_csv(
def from_pandas(cls, df: 'pd.DataFrame') -> 'DocArray':
"""
Load a DocArray from a `pandas.DataFrame` following the schema
defined in the :attr:`~docarray.DocArray.document_type` attribute.
defined in the :attr:`~docarray.DocArray._document_type` attribute.
Every row of the dataframe will be mapped to one Document in the array.
The column names of the dataframe have to match the field names of the
Document type.
Expand Down Expand Up @@ -459,13 +458,13 @@ class Person(BaseDoc):
"""
from docarray import DocArray

if cls.document_type == AnyDoc:
if cls._document_type == AnyDoc:
raise TypeError(
'There is no document schema defined. '
'Please specify the DocArray\'s Document type using `DocArray[MyDoc]`.'
)

doc_type = cls.document_type
doc_type = cls._document_type
da = DocArray.__class_getitem__(doc_type)()
field_names = df.columns.tolist()

Expand All @@ -478,7 +477,7 @@ class Person(BaseDoc):
if not all(valid_paths):
raise ValueError(
f'Column names do not match the schema of the DocArray\'s '
f'document type ({cls.document_type.__name__}): '
f'document type ({cls._document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

Expand All @@ -502,7 +501,7 @@ def to_pandas(self) -> 'pd.DataFrame':
"""
import pandas as pd

fields = self.document_type._get_access_paths()
fields = self._document_type._get_access_paths()
df = pd.DataFrame(columns=fields)

for doc in self:
Expand Down Expand Up @@ -592,7 +591,7 @@ def _load_binary_all(

# variable length bytes doc
load_protocol: str = protocol or 'protobuf'
doc = cls.document_type.from_bytes(
doc = cls._document_type.from_bytes(
d[start_doc_pos:end_doc_pos],
protocol=load_protocol,
compress=compress,
Expand Down Expand Up @@ -649,7 +648,7 @@ def _load_binary_stream(
f.read(4), 'big', signed=False
)
load_protocol: str = protocol
yield cls.document_type.from_bytes(
yield cls._document_type.from_bytes(
f.read(len_current_doc_in_bytes),
protocol=load_protocol,
compress=compress,
Expand Down
Loading