Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
dce8c6b
feat: load from and to csv
Feb 17, 2023
37892a9
fix: from to csv
Feb 20, 2023
e1e90cf
feat: add access path to dict
Feb 20, 2023
7bd7f1f
fix: from to csv
Feb 20, 2023
05e19ba
fix: clean up
Feb 20, 2023
5b0c760
docs: add docstring and update tmpdir in test
Feb 20, 2023
5babdae
fix: merge nested dicts
Feb 20, 2023
d6b29c5
fix: clean up
Feb 20, 2023
9149e57
fix: clean up
Feb 20, 2023
3d980b7
test: update test
Feb 20, 2023
516ffbb
fix: apply samis suggestion from code review
Feb 20, 2023
5fac964
Merge branch 'feat-rewrite-v2' into feat-from-to-csv
Feb 20, 2023
a0d9711
fix: apply suggestions from code review wrt access paths
Feb 21, 2023
c9005b1
fix: apply johannes suggestion
Feb 21, 2023
9fe58f5
fix: apply johannes suggestion
Feb 21, 2023
90867bb
fix: apply suggestions from code review
Feb 21, 2023
1a395da
fix: apply suggestions from code review
Feb 21, 2023
00a9ea7
fix: typos
Feb 21, 2023
e06e533
refactor: move helper functions to helper file
Feb 21, 2023
c8e4cf8
test: fix fixture
Feb 21, 2023
6e3a624
feat: add to and from pandas df for documentarray
Feb 22, 2023
12df134
chore: add pandas to pyproject.toml
Feb 22, 2023
504c13c
docs: update docstring
Feb 22, 2023
4bf8976
fix: mypy
Feb 22, 2023
ca2bf4f
fix: clean up
Feb 22, 2023
81ea38d
Merge remote-tracking branch 'origin/feat-rewrite-v2' into feat-from-…
Feb 22, 2023
5d02b35
fix: apply suggestions from code review
Feb 23, 2023
687d98a
fix: apply suggestion from johannes
Feb 23, 2023
2c02bf4
fix: apply suggestion from johannes
Feb 23, 2023
006a186
fix: apply suggestions from code review
Feb 23, 2023
c68a48c
fix: apply suggestion
Feb 23, 2023
fb74c4f
Merge branch 'feat-rewrite-v2' into feat-from-to-pandas
Feb 23, 2023
9d90053
fix: apply suggestions
Feb 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 117 additions & 22 deletions docarray/array/array/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand All @@ -26,14 +26,14 @@

from docarray.base_document import AnyDocument, BaseDocument
from docarray.helper import (
_access_path_to_dict,
_access_path_dict_to_nested_dict,
_all_access_paths_valid,
_dict_to_access_paths,
_update_nested_dicts,
is_access_path_valid,
)
from docarray.utils.compress import _decompress_bytes, _get_compress_ctx

if TYPE_CHECKING:
import pandas as pd

from docarray import DocumentArray
from docarray.proto import DocumentArrayProto
Expand Down Expand Up @@ -330,37 +330,37 @@ def from_csv(
"""
from docarray import DocumentArray

doc_type = cls.document_type
if doc_type == AnyDocument:
if cls.document_type == AnyDocument:
raise TypeError(
'There is no document schema defined. '
'To load from csv, please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.'
'Please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.'
)

doc_type = cls.document_type
da = DocumentArray.__class_getitem__(doc_type)()

with open(file_path, 'r', encoding=encoding) as fp:
rows = csv.DictReader(fp, dialect=dialect)
field_names: Optional[Sequence[Any]] = rows.fieldnames

if field_names is None:
field_names: List[str] = (
[] if rows.fieldnames is None else [str(f) for f in rows.fieldnames]
)
if field_names is None or len(field_names) == 0:
raise TypeError("No field names are given.")

valid = [is_access_path_valid(doc_type, field) for field in field_names]
if not all(valid):
valid_paths = _all_access_paths_valid(
doc_type=doc_type, access_paths=field_names
)
if not all(valid_paths):
raise ValueError(
f'Fields provided in the csv file do not match the schema of the DocumentArray\'s '
f'document type ({doc_type.__name__}): {list(compress(field_names, [not v for v in valid]))}'
f'Column names do not match the schema of the DocumentArray\'s '
f'document type ({cls.document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

for access_path2val in rows:
doc_dict: Dict[Any, Any] = {}
for access_path, value in access_path2val.items():
field2val = _access_path_to_dict(
access_path=access_path,
value=value if value not in ['', 'None'] else None,
)
_update_nested_dicts(to_update=doc_dict, update_with=field2val)

doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict(
access_path2val
)
da.append(doc_type.parse_obj(doc_dict))

return da
Expand Down Expand Up @@ -392,6 +392,101 @@ def to_csv(
doc_dict = _dict_to_access_paths(doc.dict())
writer.writerow(doc_dict)

@classmethod
def from_pandas(cls, df: 'pd.DataFrame') -> 'DocumentArray':
"""
Load a DocumentArray from a `pandas.DataFrame` following the schema
defined in the :attr:`~docarray.DocumentArray.document_type` attribute.
Every row of the dataframe will be mapped to one Document in the array.
The column names of the dataframe have to match the field names of the
Document type.
For nested fields use "__"-separated access paths as column names,
such as 'image__url'.

List-like fields (including field of type DocumentArray) are not supported.

Comment thread
anna-charlotte marked this conversation as resolved.
EXAMPLE USAGE:

.. code-block:: python

import pandas as pd

from docarray import BaseDocument, DocumentArray


class Person(BaseDocument):
name: str
follower: int


df = pd.DataFrame(
data=[['Maria', 12345], ['Jake', 54321]], columns=['name', 'follower']
)

da = DocumentArray[Person].from_pandas(df)

assert da.name == ['Maria', 'Jake']
assert da.follower == [12345, 54321]


:param df: pandas.DataFrame to extract Document's information from
:return: DocumentArray where each Document contains the information of one
corresponding row of the `pandas.DataFrame`.
"""
from docarray import DocumentArray

if cls.document_type == AnyDocument:
raise TypeError(
'There is no document schema defined. '
'Please specify the DocumentArray\'s Document type using `DocumentArray[MyDoc]`.'
)

doc_type = cls.document_type
da = DocumentArray.__class_getitem__(doc_type)()
field_names = df.columns.tolist()

if field_names is None or len(field_names) == 0:
raise TypeError("No field names are given.")

valid_paths = _all_access_paths_valid(
doc_type=doc_type, access_paths=field_names
)
if not all(valid_paths):
raise ValueError(
f'Column names do not match the schema of the DocumentArray\'s '
f'document type ({cls.document_type.__name__}): '
f'{list(compress(field_names, [not v for v in valid_paths]))}'
)

for row in df.itertuples():
access_path2val = row._asdict()
access_path2val.pop('Index', None)
doc_dict = _access_path_dict_to_nested_dict(access_path2val)
da.append(doc_type.parse_obj(doc_dict))

return da

def to_pandas(self) -> 'pd.DataFrame':
"""
Save a DocumentArray to a `pandas.DataFrame`.
The field names will be stored as column names. Each row of the dataframe corresponds
to the information of one Document.
Columns for nested fields will be named after the "__"-seperated access paths,
such as `"image__url"` for `image.url`.

:return: pandas.DataFrame
"""
import pandas as pd

fields = self.document_type._get_access_paths()
df = pd.DataFrame(columns=fields)

for doc in self:
doc_dict = _dict_to_access_paths(doc.dict())
df = df.append(doc_dict, ignore_index=True)

return df

# Methods to load from/to files in different formats
@property
def _stream_header(self) -> bytes:
Expand Down
47 changes: 41 additions & 6 deletions docarray/helper.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
from typing import TYPE_CHECKING, Any, Dict, Type
from typing import TYPE_CHECKING, Any, Dict, List, Type

if TYPE_CHECKING:
from docarray import BaseDocument


def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool:
def _is_access_path_valid(doc_type: Type['BaseDocument'], access_path: str) -> bool:
"""
Check if a given access path ("__"-separated) is a valid path for a given Document class.
"""
from docarray import BaseDocument

field, _, remaining = access_path.partition('__')
if len(remaining) == 0:
return access_path in doc.__fields__.keys()
return access_path in doc_type.__fields__.keys()
else:
valid_field = field in doc.__fields__.keys()
valid_field = field in doc_type.__fields__.keys()
if not valid_field:
return False
else:
d = doc._get_field_type(field)
d = doc_type._get_field_type(field)
if not issubclass(d, BaseDocument):
return False
else:
return is_access_path_valid(d, remaining)
return _is_access_path_valid(d, remaining)


def _all_access_paths_valid(
doc_type: Type['BaseDocument'], access_paths: List[str]
) -> List[bool]:
"""
Check if all access paths ("__"-separated) are valid for a given Document class.
"""
return [_is_access_path_valid(doc_type, path) for path in access_paths]


def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
Expand All @@ -40,6 +49,32 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
return result


def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
"""
Convert a dict, where the keys are access paths ("__"-separated) to a nested dictionary.

EXAMPLE USAGE

.. code-block:: python

access_path2val = {'image__url': 'some.png'}
assert access_path_dict_to_nested_dict(access_path2val) == {
'image': {'url': 'some.png'}
}

:param access_path2val: dict with access_paths as keys
:return: nested dict where the access path keys are split into separate field names and nested keys
"""
nested_dict: Dict[Any, Any] = {}
for access_path, value in access_path2val.items():
field2val = _access_path_to_dict(
access_path=access_path,
value=value if value not in ['', 'None'] else None,
)
_update_nested_dicts(to_update=nested_dict, update_with=field2val)
return nested_dict


def _dict_to_access_paths(d: dict) -> Dict[str, Any]:
"""
Convert a (nested) dict to a Dict[access_path, value].
Expand Down
43 changes: 39 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ fastapi = {version = ">=0.87.0", optional = true }
rich = ">=13.1.0"
lz4 = {version= ">=1.0.0", optional = true}
pydub = {version = "^0.25.1", optional = true }
pandas = {version = ">=1.1.0", optional = true }

[tool.poetry.extras]
common = ["protobuf", "lz4"]
Expand All @@ -31,6 +32,7 @@ video = ["av"]
audio = ["pydub"]
mesh = ["trimesh"]
web = ["fastapi"]
pandas = ["pandas"]

[tool.poetry.dev-dependencies]
pytest = ">=6.1"
Expand Down Expand Up @@ -60,6 +62,10 @@ check_untyped_defs = true
module = "av"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pandas"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "trimesh"
ignore_missing_imports = true
Expand Down
4 changes: 1 addition & 3 deletions tests/units/array/test_array_from_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ def test_from_csv_without_schema_raise_exception():


def test_from_csv_with_wrong_schema_raise_exception(nested_doc):
with pytest.raises(
ValueError, match='Fields provided in the csv file do not match the schema'
):
with pytest.raises(ValueError, match='Column names do not match the schema'):
DocumentArray[nested_doc.__class__].from_csv(
file_path=str(TOYDATA_DIR / 'docs.csv')
)
Loading