Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):

class _DocArrayTyped(cls): # type: ignore
document_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
__origin__ = cls
__args__ = [item]

for field in _DocArrayTyped.document_type.__fields__.keys():

Expand Down
20 changes: 1 addition & 19 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
from docarray.typing import NdArray

if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.array.stacked.array_stacked import DocArrayStacked
from docarray.proto import DocumentArrayProto
from docarray.typing import TorchTensor
Expand Down Expand Up @@ -127,6 +124,7 @@ def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
):
super().__init__()
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []

@classmethod
Expand Down Expand Up @@ -259,22 +257,6 @@ def stack(
self, tensor_type=tensor_type
)

@classmethod
def validate(
cls: Type[T],
value: Union[T, Iterable[BaseDoc]],
field: 'ModelField',
config: 'BaseConfig',
):
from docarray.array.stacked.array_stacked import DocArrayStacked

if isinstance(value, (cls, DocArrayStacked)):
return value
elif isinstance(value, Iterable):
return cls(value)
else:
raise TypeError(f'Expecting an Iterable of {cls.document_type}')

def traverse_flat(
self: 'DocArray',
access_path: str,
Expand Down
8 changes: 4 additions & 4 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def __getitem__(self, item: slice):


class IOMixinArray(Iterable[BaseDoc]):

document_type: Type[BaseDoc]

@abstractmethod
Expand Down Expand Up @@ -250,7 +249,7 @@ def to_bytes(
:return: the binary serialization in bytes or None if file_ctx is passed where to store
"""

with (file_ctx or io.BytesIO()) as bf:
with file_ctx or io.BytesIO() as bf:
self._write_bytes(
bf=bf,
protocol=protocol,
Expand Down Expand Up @@ -318,13 +317,14 @@ def from_json(
:return: the deserialized DocArray
"""
json_docs = json.loads(file)
return cls([cls.document_type.parse_raw(v) for v in json_docs])
return cls([cls.document_type(**v) for v in json_docs])

def to_json(self) -> str:
"""Convert the object into a JSON string. Can be loaded via :meth:`.from_json`.
:return: JSON serialization of DocArray
"""
return json.dumps([doc.json() for doc in self])
doc_jsons = ', '.join([doc.json() for doc in self])
return f'[{doc_jsons}]'

@classmethod
def from_csv(
Expand Down
5 changes: 3 additions & 2 deletions docarray/array/array/sequence_indexing_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
cast,
no_type_check,
overload,
List,
)

import numpy as np
Expand All @@ -33,9 +34,9 @@ def _is_np_int(item: Any) -> bool:
return False # this is unreachable, but mypy wants it


class IndexingSequenceMixin(Iterable[T_item]):
class IndexingSequenceMixin(List[T_item]):
"""
This mixin allow sto extend a list into an object that can be indexed
This mixin allows to extend a list into an object that can be indexed
a la numpy/pytorch.

You can index into, delete from, and set items in a IndexingSequenceMixin like a numpy array or torch tensor:
Expand Down
3 changes: 2 additions & 1 deletion docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from docarray.base_doc.io.json import orjson_dumps, orjson_dumps_and_decode
from docarray.base_doc.mixins import IOMixin, UpdateMixin
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from docarray.array.stacked.column_storage import ColumnStorageView
Expand All @@ -28,7 +29,7 @@ class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
class Config:
json_loads = orjson.loads
json_dumps = orjson_dumps_and_decode
json_encoders = {dict: orjson_dumps}
json_encoders = {AbstractTensor: orjson_dumps}

validate_assignment = True

Expand Down
41 changes: 24 additions & 17 deletions docarray/typing/abstract_type.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,32 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar

from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.base_doc.base_node import BaseNode

if TYPE_CHECKING:
from docarray.proto import NodeProto

T = TypeVar('T')


class AbstractType(BaseNode):
class AbstractType(ABC):
_proto_type_name: Optional[str] = None

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
@abstractmethod
def validate(
cls: Type[T],
value: Any,
field: 'ModelField',
config: 'BaseConfig',
) -> T:
...

@classmethod
@abstractmethod
def from_protobuf(cls: Type[T], pb_msg: T) -> T:
...

@abstractmethod
def _to_node_protobuf(self: T) -> 'NodeProto':
"""Convert itself into a NodeProto message. This function should
be called when the self is nested into another Document that need to be
converted into a protobuf

:return: the nested item protobuf message
"""
...

def _docarray_to_json_compatible(self):
Expand All @@ -44,3 +35,19 @@ def _docarray_to_json_compatible(self):
:return: a representation of the tensor compatible with orjson
"""
return self


class AbstractValidator(ABC):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
@abstractmethod
def validate(
cls: Type[T],
value: Any,
field: 'ModelField',
config: 'BaseConfig',
) -> T:
...
26 changes: 25 additions & 1 deletion tests/integrations/externals/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import List

import numpy as np
import pytest
from fastapi import FastAPI
from httpx import AsyncClient

from docarray import BaseDoc
from docarray import BaseDoc, DocArray
from docarray.base_doc import DocResponse
from docarray.documents import ImageDoc, TextDoc
from docarray.typing import NdArray
Expand Down Expand Up @@ -107,3 +109,25 @@ async def create_item(doc: InputDoc) -> OutputDoc:
assert isinstance(doc, OutputDoc)
assert doc.embedding_clip.shape == (100, 1)
assert doc.embedding_bert.shape == (100, 1)


@pytest.mark.asyncio
async def test_docarray():
doc = TextDoc(text='some txt')
docs = DocArray[TextDoc]([doc])

app = FastAPI()

@app.post("/doc/")
async def func(fastapi_docs: List[TextDoc]) -> DocArray[TextDoc]:
fastapi_docs = DocArray[TextDoc].construct(fastapi_docs)
return fastapi_docs

async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.post("/doc/", data=docs.to_json())

assert response.status_code == 200

docs = DocArray[TextDoc].from_json(response.content.decode())
assert docs and len(docs) == 1
assert docs[0].text == 'some txt'