Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypeVar,
Union,
overload,
Dict,
)

from typing_inspect import is_union_type
Expand All @@ -26,6 +27,8 @@
)
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray
from pydantic.generics import GenericModel


if TYPE_CHECKING:
from pydantic import BaseConfig
Expand All @@ -52,13 +55,13 @@ def _delegate_meth_to_data(meth_name: str) -> Callable:

@wraps(func)
def _delegate_meth(self, *args, **kwargs):
return getattr(self._data, meth_name)(*args, **kwargs)
return getattr(self.data, meth_name)(*args, **kwargs)

return _delegate_meth


class DocArray(
IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
GenericModel, IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
):
"""
DocArray is a container of Documents.
Expand Down Expand Up @@ -121,13 +124,15 @@ class Image(BaseDoc):

"""

document_type: Type[BaseDoc] = AnyDoc
data: List[T_doc] = []
_document_type: Type[BaseDoc] = AnyDoc

def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
):
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []
super().__init__()
self.data: List[T_doc] = list(self._validate_docs(docs)) if docs else []

@classmethod
def construct(
Expand All @@ -141,7 +146,7 @@ def construct(
:return:
"""
da = cls.__new__(cls)
da._data = docs if isinstance(docs, list) else list(docs)
da.data = docs if isinstance(docs, list) else list(docs)
return da

def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
Expand All @@ -153,17 +158,17 @@ def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:

def _validate_one_doc(self, doc: T_doc) -> T_doc:
"""Validate if a Document is compatible with this DocArray"""
if not issubclass(self.document_type, AnyDoc) and not isinstance(
doc, self.document_type
if not issubclass(self._document_type, AnyDoc) and not isinstance(
doc, self._document_type
):
raise ValueError(f'{doc} is not a {self.document_type}')
raise ValueError(f'{doc} is not a {self._document_type}')
return doc

def __len__(self):
return len(self._data)
return len(self.data)

def __iter__(self):
return iter(self._data)
return iter(self.data)

def __bytes__(self) -> bytes:
with io.BytesIO() as bf:
Expand All @@ -176,7 +181,7 @@ def append(self, doc: T_doc):
as the document_type of this DocArray otherwise it will fail.
:param doc: A Document
"""
self._data.append(self._validate_one_doc(doc))
self.data.append(self._validate_one_doc(doc))

def extend(self, docs: Iterable[T_doc]):
"""
Expand All @@ -185,7 +190,7 @@ def extend(self, docs: Iterable[T_doc]):
fail.
:param docs: Iterable of Documents
"""
self._data.extend(self._validate_docs(docs))
self.data.extend(self._validate_docs(docs))

def insert(self, i: int, doc: T_doc):
"""
Expand All @@ -194,7 +199,7 @@ class as the document_type of this DocArray otherwise it will fail.
:param i: index to insert
:param doc: A Document
"""
self._data.insert(i, self._validate_one_doc(doc))
self.data.insert(i, self._validate_one_doc(doc))

pop = _delegate_meth_to_data('pop')
remove = _delegate_meth_to_data('remove')
Expand Down Expand Up @@ -270,10 +275,12 @@ def validate(

if isinstance(value, (cls, DocArrayStacked)):
return value
elif isinstance(value, Dict):
return cls([cls._document_type(**v) for v in value['data']])
elif isinstance(value, Iterable):
return cls(value)
else:
raise TypeError(f'Expecting an Iterable of {cls.document_type}')
raise TypeError(f'Expecting an Iterable of {cls._document_type}')

def traverse_flat(
self: 'DocArray',
Expand Down
30 changes: 15 additions & 15 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __getitem__(self, item: slice):

class IOMixinArray(Iterable[BaseDoc]):

document_type: Type[BaseDoc]
_document_type: Type[BaseDoc]

@abstractmethod
def __len__(self):
Expand All @@ -113,7 +113,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T:
:param pb_msg: The protobuf message from where to construct the DocArray
"""
return cls(
cls.document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs
cls._document_type.from_protobuf(doc_proto) for doc_proto in pb_msg.docs
)

def to_protobuf(self) -> 'DocumentArrayProto':
Expand Down Expand Up @@ -250,7 +250,7 @@ def to_bytes(
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""

with (file_ctx or io.BytesIO()) as bf:
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
protocol=protocol,
Expand Down Expand Up @@ -318,7 +318,7 @@ def from_json(
:return: the deserialized DocArray
"""
json_docs = json.loads(file)
return cls([cls.document_type.parse_raw(v) for v in json_docs])
return cls([cls._document_type(**v) for v in json_docs['data']])

def to_json(self) -> str:
"""Convert the object into a JSON string. Can be loaded via :meth:`.from_json`.
Expand All @@ -335,7 +335,7 @@ def from_csv(
) -> 'DocArray':
"""
Load a DocArray from a csv file following the schema defined in the
:attr:`~docarray.DocArray.document_type` attribute.
:attr:`~docarray.DocArray._document_type` attribute.
Every row of the csv file will be mapped to one document in the array.
The column names (defined in the first row) have to match the field names
of the Document type.
Expand All @@ -354,13 +354,13 @@ def from_csv(
"""
from docarray import DocArray

if cls.document_type == AnyDoc:
if cls._document_type == AnyDoc:
raise TypeError(
'There is no document schema defined. '
'Please specify the DocArray\'s Document type using `DocArray[MyDoc]`.'
)

doc_type = cls.document_type
doc_type = cls._document_type
da = DocArray.__class_getitem__(doc_type)()

with open(file_path, 'r', encoding=encoding) as fp:
Expand All @@ -377,7 +377,7 @@ def from_csv(
if not all(valid_paths):
raise ValueError(
f'Column names do not match the schema of the DocArray\'s '
f'document type ({cls.document_type.__name__}): '
f'document type ({cls._document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

Expand Down Expand Up @@ -406,7 +406,7 @@ def to_csv(
'excel-tab' (for tab separated values),
'unix' (for csv file generated on UNIX systems).
"""
fields = self.document_type._get_access_paths()
fields = self._document_type._get_access_paths()

with open(file_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=fields, dialect=dialect)
Expand All @@ -420,7 +420,7 @@ def to_csv(
def from_pandas(cls, df: 'pd.DataFrame') -> 'DocArray':
"""
Load a DocArray from a `pandas.DataFrame` following the schema
defined in the :attr:`~docarray.DocArray.document_type` attribute.
defined in the :attr:`~docarray.DocArray._document_type` attribute.
Every row of the dataframe will be mapped to one Document in the array.
The column names of the dataframe have to match the field names of the
Document type.
Expand Down Expand Up @@ -459,13 +459,13 @@ class Person(BaseDoc):
"""
from docarray import DocArray

if cls.document_type == AnyDoc:
if cls._document_type == AnyDoc:
raise TypeError(
'There is no document schema defined. '
'Please specify the DocArray\'s Document type using `DocArray[MyDoc]`.'
)

doc_type = cls.document_type
doc_type = cls._document_type
da = DocArray.__class_getitem__(doc_type)()
field_names = df.columns.tolist()

Expand All @@ -478,7 +478,7 @@ class Person(BaseDoc):
if not all(valid_paths):
raise ValueError(
f'Column names do not match the schema of the DocArray\'s '
f'document type ({cls.document_type.__name__}): '
f'document type ({cls._document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

Expand All @@ -502,7 +502,7 @@ def to_pandas(self) -> 'pd.DataFrame':
"""
import pandas as pd

fields = self.document_type._get_access_paths()
fields = self._document_type._get_access_paths()
df = pd.DataFrame(columns=fields)

for doc in self:
Expand Down Expand Up @@ -592,7 +592,7 @@ def _load_binary_all(

# variable length bytes doc
load_protocol: str = protocol or 'protobuf'
doc = cls.document_type.from_bytes(
doc = cls._document_type.from_bytes(
d[start_doc_pos:end_doc_pos],
protocol=load_protocol,
compress=compress,
Expand Down
20 changes: 10 additions & 10 deletions docarray/array/array/sequence_indexing_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class IndexingSequenceMixin(Iterable[T_item]):

"""

_data: MutableSequence[T_item]
data: MutableSequence[T_item]

@abc.abstractmethod
def __init__(
Expand Down Expand Up @@ -102,13 +102,13 @@ def _normalize_index_item(
def _get_from_indices(self: T, item: Iterable[int]) -> T:
results = []
for ix in item:
results.append(self._data[ix])
results.append(self.data[ix])
return self.__class__(results)

def _set_by_indices(self: T, item: Iterable[int], value: Iterable[T_item]):
for ix, doc_to_set in zip(item, value):
try:
self._data[ix] = doc_to_set
self.data[ix] = doc_to_set
except KeyError:
raise IndexError(f'Index {ix} is out of range')

Expand All @@ -121,7 +121,7 @@ def _set_by_mask(self: T, item: Iterable[bool], value: Sequence[T_item]):
i_value = 0
for i, mask_value in zip(range(len(self)), item):
if mask_value:
self._data[i] = value[i_value]
self.data[i] = value[i_value]
i_value += 1

def _del_from_mask(self: T, item: Iterable[bool]) -> None:
Expand All @@ -132,15 +132,15 @@ def _del_from_indices(self: T, item: Iterable[int]) -> None:
for ix in sorted(item, reverse=True):
# reverse (descending) order is needed here, otherwise some of the indices
# are no longer valid after each delete
del self._data[ix]
del self.data[ix]

def __delitem__(self, key: Union[int, IndexIterType]) -> None:
item = self._normalize_index_item(key)

if item is None:
return
elif isinstance(item, (int, slice)):
del self._data[item]
del self.data[item]
else:
head = item[0] # type: ignore
if isinstance(head, bool):
Expand All @@ -163,10 +163,10 @@ def __getitem__(self, item):
item = self._normalize_index_item(item)

if type(item) == slice:
return self.__class__(self._data[item])
return self.__class__(self.data[item])

if isinstance(item, int):
return self._data[item]
return self.data[item]

if item is None:
return self
Expand All @@ -193,9 +193,9 @@ def __setitem__(self: T, key, value):
key_norm = self._normalize_index_item(key)

if isinstance(key_norm, int):
self._data[key_norm] = value
self.data[key_norm] = value
elif isinstance(key_norm, slice):
self._data[key_norm] = value
self.data[key_norm] = value
else:
# _normalize_index_item() guarantees the line below is correct
head = key_norm[0]
Expand Down
2 changes: 1 addition & 1 deletion docarray/display/document_array_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def summary(self) -> None:
table.add_row(f' • {field_name}:', col_2)

Console().print(Panel(table, title='DocArray Summary', expand=False))
self.da.document_type.schema_summary()
self.da._document_type.schema_summary()

@staticmethod
def _get_stacked_fields(da: 'DocArrayStacked') -> List[str]: # TODO this might
Expand Down
26 changes: 26 additions & 0 deletions tests/integrations/externals/test_documentarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient

from docarray import DocumentArray

from docarray.documents import TextDoc


@pytest.mark.asyncio
async def test_fast_api():
    """Round-trip a ``DocumentArray[TextDoc]`` through a FastAPI endpoint.

    Builds a one-document array, posts its JSON form to an echo endpoint
    whose request and response models are both ``DocumentArray[TextDoc]``,
    checks that the request succeeds, and rebuilds an array from the
    response body with ``from_json``.
    """
    doc = TextDoc(text='some txt')
    docs = DocumentArray[TextDoc](docs=[doc])
    app = FastAPI()

    @app.post("/doc/")
    async def func(fastapi_docs: DocumentArray[TextDoc]) -> DocumentArray[TextDoc]:
        # Echo endpoint: FastAPI validates the body into the array type and
        # serializes the same object back out.
        return fastapi_docs

    async with AsyncClient(app=app, base_url="http://test") as ac:
        # The payload is already a JSON string, so send it as the raw request
        # body with ``content=``. httpx reserves ``data=`` for form fields and
        # emits a DeprecationWarning when it is given raw str/bytes.
        response = await ac.post("/doc/", content=docs.json())

    assert response.status_code == 200

    # Rebuilding and summarizing the array checks that the response payload is
    # loadable by ``from_json``.
    returned_docs = DocumentArray[TextDoc].from_json(response.content.decode())
    returned_docs.summary()