Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 103 additions & 2 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,45 @@
import base64
import csv
import io
import json
import os
import pathlib
import pickle
from abc import abstractmethod
from contextlib import nullcontext
from itertools import compress
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
ContextManager,
Dict,
Generator,
Iterable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from docarray.base_document import BaseDocument
from docarray.base_document import AnyDocument, BaseDocument
from docarray.helper import (
_access_path_to_dict,
_dict_to_access_paths,
_update_nested_dicts,
is_access_path_valid,
)
from docarray.utils.compress import _decompress_bytes, _get_compress_ctx

if TYPE_CHECKING:

from docarray import DocumentArray
from docarray.proto import DocumentArrayProto

T = TypeVar('T', bound='IOMixinArray')


ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array'}
SINGLE_PROTOCOLS = {'pickle', 'protobuf'}
ALLOWED_PROTOCOLS = ARRAY_PROTOCOLS.union(SINGLE_PROTOCOLS)
Expand Down Expand Up @@ -291,6 +302,96 @@ def to_json(self) -> str:
"""
return json.dumps([doc.json() for doc in self])

@classmethod
def from_csv(
cls,
file_path: str,
encoding: str = 'utf-8',
dialect: Union[str, csv.Dialect] = 'excel',
) -> 'DocumentArray':
"""
Load a DocumentArray from a csv file following the schema defined in the
:attr:`~docarray.DocumentArray.document_type` attribute.
Every row of the csv file will be mapped to one document in the array.
The column names (defined in the first row) have to match the field names
of the Document type.
For nested fields use "__"-separated access paths, such as 'image__url'.

List-like fields (including field of type DocumentArray) are not supported.

Comment thread
anna-charlotte marked this conversation as resolved.
:param file_path: path to csv file to load DocumentArray from.
:param encoding: encoding used to read the csv file. Defaults to 'utf-8'.
:param dialect: defines separator and how to handle whitespaces etc.
Can be a csv.Dialect instance or one string of:
'excel' (for comma seperated values),
'excel-tab' (for tab separated values),
'unix' (for csv file generated on UNIX systems).
:return: DocumentArray
"""
from docarray import DocumentArray

doc_type = cls.document_type
if doc_type == AnyDocument:
raise TypeError(
'There is no document schema defined. '
'To load from csv, please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.'
)

da = DocumentArray.__class_getitem__(doc_type)()
with open(file_path, 'r', encoding=encoding) as fp:
rows = csv.DictReader(fp, dialect=dialect)
field_names: Optional[Sequence[Any]] = rows.fieldnames

if field_names is None:
raise TypeError("No field names are given.")

valid = [is_access_path_valid(doc_type, field) for field in field_names]
if not all(valid):
raise ValueError(
f'Fields provided in the csv file do not match the schema of the DocumentArray\'s '
f'document type ({doc_type.__name__}): {list(compress(field_names, [not v for v in valid]))}'
)

for access_path2val in rows:
doc_dict: Dict[Any, Any] = {}
for access_path, value in access_path2val.items():
field2val = _access_path_to_dict(
access_path=access_path,
value=value if value not in ['', 'None'] else None,
)
_update_nested_dicts(to_update=doc_dict, update_with=field2val)

da.append(doc_type.parse_obj(doc_dict))

return da

def to_csv(
self, file_path: str, dialect: Union[str, csv.Dialect] = 'excel'
) -> None:
"""
Save a DocumentArray to a csv file.
The field names will be stored in the first row. Each row corresponds to the
information of one Document.
Columns for nested fields will be named after the "__"-seperated access paths,
such as `"image__url"` for `image.url`.

:param file_path: path to a csv file.
:param dialect: defines separator and how to handle whitespaces etc.
Can be a csv.Dialect instance or one string of:
'excel' (for comma seperated values),
'excel-tab' (for tab separated values),
'unix' (for csv file generated on UNIX systems).
"""
fields = self.document_type._get_access_paths()

with open(file_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=fields, dialect=dialect)
writer.writeheader()

for doc in self:
doc_dict = _dict_to_access_paths(doc.dict())
writer.writerow(doc_dict)

# Methods to load from/to files in different formats
@property
def _stream_header(self) -> bytes:
Expand Down
23 changes: 23 additions & 0 deletions docarray/base_document/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
)

from typing_inspect import is_union_type

from docarray.base_document.base_node import BaseNode
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
from docarray.utils.compress import _compress_bytes, _decompress_bytes
Expand Down Expand Up @@ -291,3 +294,23 @@ def _to_node_protobuf(self) -> 'NodeProto':
:return: the nested item protobuf message
"""
return NodeProto(document=self.to_protobuf())

@classmethod
def _get_access_paths(cls) -> List[str]:
"""
Get "__"-separated access paths of all fields, including nested ones.

:return: list of all access paths
"""
from docarray import BaseDocument

paths = []
for field in cls.__fields__.keys():
field_type = cls._get_field_type(field)
if not is_union_type(field_type) and issubclass(field_type, BaseDocument):
sub_paths = field_type._get_access_paths()
for path in sub_paths:
paths.append(f'{field}__{path}')
else:
paths.append(field)
return paths
88 changes: 88 additions & 0 deletions docarray/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Type

if TYPE_CHECKING:
from docarray import BaseDocument


def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool:
"""
Check if a given access path ("__"-separated) is a valid path for a given Document class.
"""
from docarray import BaseDocument

field, _, remaining = access_path.partition('__')
if len(remaining) == 0:
return access_path in doc.__fields__.keys()
else:
valid_field = field in doc.__fields__.keys()
if not valid_field:
return False
else:
d = doc._get_field_type(field)
if not issubclass(d, BaseDocument):
return False
else:
return is_access_path_valid(d, remaining)


def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
"""
Convert an access path ("__"-separated) and its value to a (potentially) nested dict.

EXAMPLE USAGE
.. code-block:: python
assert access_path_to_dict('image__url', 'img.png') == {'image': {'url': 'img.png'}}
"""
fields = access_path.split('__')
for field in reversed(fields):
result = {field: value}
value = result
return result


def _dict_to_access_paths(d: dict) -> Dict[str, Any]:
"""
Convert a (nested) dict to a Dict[access_path, value].
Access paths are defined as a path of field(s) separated by "__".

EXAMPLE USAGE
.. code-block:: python
assert dict_to_access_paths({'image': {'url': 'img.png'}}) == {'image__url', 'img.png'}
"""
result = {}
for k, v in d.items():
if isinstance(v, dict):
v = _dict_to_access_paths(v)
for nested_k, nested_v in v.items():
new_key = '__'.join([k, nested_k])
result[new_key] = nested_v
else:
result[k] = v
return result


def _update_nested_dicts(
to_update: Dict[Any, Any], update_with: Dict[Any, Any]
) -> None:
"""
Update a dict with another one, while considering shared nested keys.

EXAMPLE USAGE:

.. code-block:: python

d1 = {'image': {'tensor': None}, 'title': 'hello'}
d2 = {'image': {'url': 'some.png'}}

update_nested_dicts(d1, d2)
assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'}

:param to_update: dict that should be updated
:param update_with: dict to update with
:return: merged dict
"""
for k, v in update_with.items():
if k not in to_update.keys():
to_update[k] = v
else:
_update_nested_dicts(to_update[k], update_with[k])
4 changes: 4 additions & 0 deletions tests/toydata/docs_nested.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
count,text,image,image2__url
000,hello 0,image_0.png,image_10.png
111,hello 1,image_1.png,None
222,hello 2,image_2.png,
101 changes: 101 additions & 0 deletions tests/units/array/test_array_from_to_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
from typing import Optional

import pytest

from docarray import BaseDocument, DocumentArray
from docarray.documents import Image
from tests import TOYDATA_DIR


@pytest.fixture()
def nested_doc_cls():
class MyDoc(BaseDocument):
count: Optional[int]
text: str

class MyDocNested(MyDoc):
image: Image
image2: Image

return MyDocNested


def test_to_from_csv(tmpdir, nested_doc_cls):
da = DocumentArray[nested_doc_cls](
[
nested_doc_cls(
count=0,
text='hello',
image=Image(url='aux.png'),
image2=Image(url='aux.png'),
),
nested_doc_cls(text='hello world', image=Image(), image2=Image()),
]
)
tmp_file = str(tmpdir / 'tmp.csv')
da.to_csv(tmp_file)
assert os.path.isfile(tmp_file)

da_from = DocumentArray[nested_doc_cls].from_csv(tmp_file)
for doc1, doc2 in zip(da, da_from):
assert doc1 == doc2


def test_from_csv_nested(nested_doc_cls):
da = DocumentArray[nested_doc_cls].from_csv(
file_path=str(TOYDATA_DIR / 'docs_nested.csv')
)
assert len(da) == 3

for i, doc in enumerate(da):
assert doc.count.__class__ == int
assert doc.count == int(f'{i}{i}{i}')

assert doc.text.__class__ == str
assert doc.text == f'hello {i}'

assert doc.image.__class__ == Image
assert doc.image.tensor is None
assert doc.image.embedding is None
assert doc.image.bytes is None

assert doc.image2.__class__ == Image
assert doc.image2.tensor is None
assert doc.image2.embedding is None
assert doc.image2.bytes is None

assert da[0].image2.url == 'image_10.png'
assert da[1].image2.url is None
assert da[2].image2.url is None


@pytest.fixture()
def nested_doc():
class Inner(BaseDocument):
img: Optional[Image]

class Middle(BaseDocument):
img: Optional[Image]
inner: Optional[Inner]

class Outer(BaseDocument):
img: Optional[Image]
middle: Optional[Middle]

doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image())))
return doc


def test_from_csv_without_schema_raise_exception():
with pytest.raises(TypeError, match='no document schema defined'):
DocumentArray.from_csv(file_path=str(TOYDATA_DIR / 'docs_nested.csv'))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense that this raise an Expection for now. But what about implementing at this level the auto schema detection ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, for now the exception, I think the auto detection or handing over a schema some other way then setting .document_type I would handle in a separate PR, if that's good with you.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes makes sense to do it in another PR



def test_from_csv_with_wrong_schema_raise_exception(nested_doc):
with pytest.raises(
ValueError, match='Fields provided in the csv file do not match the schema'
):
DocumentArray[nested_doc.__class__].from_csv(
file_path=str(TOYDATA_DIR / 'docs.csv')
)
7 changes: 2 additions & 5 deletions tests/units/array/test_array_from_to_json.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import pytest

from docarray import BaseDocument
from docarray.typing import NdArray
from docarray import BaseDocument, DocumentArray
from docarray.documents import Image
from docarray import DocumentArray
from docarray.typing import NdArray


class MyDoc(BaseDocument):
Expand Down
Loading