-
Notifications
You must be signed in to change notification settings - Fork 244
feat(v2): load da from csv and save to csv #1144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dce8c6b
37892a9
e1e90cf
7bd7f1f
05e19ba
5b0c760
5babdae
d6b29c5
9149e57
3d980b7
516ffbb
5fac964
a0d9711
c9005b1
9fe58f5
90867bb
1a395da
00a9ea7
e06e533
c8e4cf8
d0a3141
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| from typing import TYPE_CHECKING, Any, Dict, Type | ||
|
|
||
| if TYPE_CHECKING: | ||
| from docarray import BaseDocument | ||
|
|
||
|
|
||
def is_access_path_valid(doc: Type['BaseDocument'], access_path: str) -> bool:
    """
    Check if a given access path ("__"-separated) is a valid path for a given
    Document class.

    The path is consumed one segment at a time: every segment has to be a key
    of ``doc.__fields__``, and every segment except the last one has to
    resolve to a ``BaseDocument`` subclass so that the remaining path can be
    checked against it.

    :param doc: Document class the path is checked against
    :param access_path: "__"-separated field names, e.g. ``'image__url'``
    :return: True if the path can be resolved on ``doc``, False otherwise
    """
    # Imported locally; at module level the import is guarded by TYPE_CHECKING
    # (presumably to avoid a circular import -- confirm before moving it).
    from docarray import BaseDocument

    field, _, remaining = access_path.partition('__')
    if not remaining:
        # No further nesting: the path as given has to name a field directly.
        return access_path in doc.__fields__
    if field not in doc.__fields__:
        return False

    nested_type = doc._get_field_type(field)
    try:
        is_document = issubclass(nested_type, BaseDocument)
    except TypeError:
        # issubclass() raises TypeError when its first argument is not a class
        # (e.g. a parametrized generic). Such a field cannot be traversed any
        # further, so the path is invalid rather than an error.
        return False
    return is_document and is_access_path_valid(nested_type, remaining)
|
|
||
|
|
||
def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
    """
    Turn an access path ("__"-separated) and its value into a dict that is
    nested once per path segment.

    EXAMPLE USAGE

    .. code-block:: python

        assert _access_path_to_dict('image__url', 'img.png') == {
            'image': {'url': 'img.png'}
        }

    :param access_path: "__"-separated field names
    :param value: value stored under the innermost field
    :return: (potentially) nested dict holding ``value`` at ``access_path``
    """
    # str.split always yields at least one element, so ``leaf`` is always bound.
    *parents, leaf = access_path.split('__')
    nested: Dict[str, Any] = {leaf: value}
    # Wrap from the inside out so the first segment ends up outermost.
    for name in reversed(parents):
        nested = {name: nested}
    return nested
|
|
||
|
|
||
def _dict_to_access_paths(d: dict) -> Dict[str, Any]:
    """
    Flatten a (nested) dict into a Dict[access_path, value].
    An access path is the chain of keys leading to a value, joined by "__".

    A key whose value is an empty dict contributes no entry to the result.

    EXAMPLE USAGE

    .. code-block:: python

        assert _dict_to_access_paths({'image': {'url': 'img.png'}}) == {
            'image__url': 'img.png'
        }

    :param d: (nested) dict to flatten
    :return: flat dict mapping each access path to its leaf value
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        if not isinstance(value, dict):
            flat[key] = value
            continue
        # Flatten the sub-dict first, then prefix each of its paths with ``key``.
        for sub_path, leaf in _dict_to_access_paths(value).items():
            flat['__'.join((key, sub_path))] = leaf
    return flat
|
|
||
|
|
||
def _update_nested_dicts(
    to_update: Dict[Any, Any], update_with: Dict[Any, Any]
) -> None:
    """
    Update a dict with another one in place, while considering shared nested
    keys.

    Keys that hold a dict on both sides are merged recursively. For any other
    key the value from `update_with` is stored in `to_update`, replacing an
    existing non-dict value (same as `dict.update`).

    EXAMPLE USAGE:

    .. code-block:: python

        d1 = {'image': {'tensor': None}, 'title': 'hello'}
        d2 = {'image': {'url': 'some.png'}}

        _update_nested_dicts(d1, d2)
        assert d1 == {'image': {'tensor': None, 'url': 'some.png'}, 'title': 'hello'}

    :param to_update: dict that should be updated (modified in place)
    :param update_with: dict to update with
    :return: None, the result of the merge is `to_update` itself
    """
    for key, new_value in update_with.items():
        current = to_update.get(key)
        if isinstance(current, dict) and isinstance(new_value, dict):
            # Both sides are nested: merge instead of replacing the sub-dict.
            _update_nested_dicts(current, new_value)
        else:
            # Missing key, or at least one side is a leaf value. Recursing
            # here would call .items()/.keys() on a non-dict and raise.
            to_update[key] = new_value
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| count,text,image,image2__url | ||
| 000,hello 0,image_0.png,image_10.png | ||
| 111,hello 1,image_1.png,None | ||
| 222,hello 2,image_2.png, |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| import pytest | ||
|
|
||
| from docarray import BaseDocument, DocumentArray | ||
| from docarray.documents import Image | ||
| from tests import TOYDATA_DIR | ||
|
|
||
|
|
||
@pytest.fixture()
def nested_doc_cls():
    """Provide a Document class with two flat fields and two nested ``Image`` fields."""

    class MyDoc(BaseDocument):
        # Flat fields; ``count`` is optional, ``text`` is required.
        count: Optional[int]
        text: str

    class MyDocNested(MyDoc):
        # Nested documents, addressed in CSV columns via "__" access paths
        # such as ``image2__url``.
        image: Image
        image2: Image

    return MyDocNested
|
|
||
|
|
||
def test_to_from_csv(tmpdir, nested_doc_cls):
    """Round-trip a DocumentArray through ``to_csv`` and ``from_csv``."""
    da = DocumentArray[nested_doc_cls](
        [
            nested_doc_cls(
                count=0,
                text='hello',
                image=Image(url='aux.png'),
                image2=Image(url='aux.png'),
            ),
            # Second document omits ``count`` and passes no url to the images.
            nested_doc_cls(text='hello world', image=Image(), image2=Image()),
        ]
    )
    tmp_file = str(tmpdir / 'tmp.csv')
    da.to_csv(tmp_file)
    assert os.path.isfile(tmp_file)

    da_from = DocumentArray[nested_doc_cls].from_csv(tmp_file)
    # zip() stops at the shorter input, so without this check the loop below
    # would also pass if rows got lost on the way through the file.
    assert len(da_from) == len(da)
    for doc1, doc2 in zip(da, da_from):
        assert doc1 == doc2
|
|
||
|
|
||
def test_from_csv_nested(nested_doc_cls):
    """Load the nested toy CSV and check type and value of every field."""
    csv_path = str(TOYDATA_DIR / 'docs_nested.csv')
    da = DocumentArray[nested_doc_cls].from_csv(file_path=csv_path)
    assert len(da) == 3

    for i, doc in enumerate(da):
        # Row i is expected to hold the count "iii" and the text "hello i".
        assert doc.count.__class__ == int
        assert doc.count == int(f'{i}{i}{i}')

        assert doc.text.__class__ == str
        assert doc.text == f'hello {i}'

        # Both nested images get the same checks: correct class, and the
        # tensor, embedding and bytes fields are left unset.
        for img in (doc.image, doc.image2):
            assert img.__class__ == Image
            assert img.tensor is None
            assert img.embedding is None
            assert img.bytes is None

    # Only the first row is expected to end up with a url for ``image2``.
    assert da[0].image2.url == 'image_10.png'
    assert da[1].image2.url is None
    assert da[2].image2.url is None
|
|
||
|
|
||
@pytest.fixture()
def nested_doc():
    """Provide a document nested three levels deep (Outer -> Middle -> Inner)."""

    class Inner(BaseDocument):
        img: Optional[Image]

    class Middle(BaseDocument):
        img: Optional[Image]
        inner: Optional[Inner]

    class Outer(BaseDocument):
        img: Optional[Image]
        middle: Optional[Middle]

    # Assemble from the innermost level outwards; each level gets its own Image.
    innermost = Inner(img=Image())
    between = Middle(img=Image(), inner=innermost)
    return Outer(img=Image(), middle=between)
|
|
||
|
|
||
def test_from_csv_without_schema_raise_exception():
    """``from_csv`` on a DocumentArray without a document type raises TypeError."""
    csv_path = str(TOYDATA_DIR / 'docs_nested.csv')
    with pytest.raises(TypeError, match='no document schema defined'):
        DocumentArray.from_csv(file_path=csv_path)
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense that this raises an Exception for now. But what about implementing the auto schema detection at this level?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, for now the exception. I think the auto detection or handing over a schema some other way than setting
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes makes sense to do it in another PR |
||
|
|
||
|
|
||
def test_from_csv_with_wrong_schema_raise_exception(nested_doc):
    """``from_csv`` raises ValueError when the CSV fields do not match the schema."""
    doc_cls = nested_doc.__class__
    csv_path = str(TOYDATA_DIR / 'docs.csv')
    expected_msg = 'Fields provided in the csv file do not match the schema'
    with pytest.raises(ValueError, match=expected_msg):
        DocumentArray[doc_cls].from_csv(file_path=csv_path)
Uh oh!
There was an error while loading. Please reload this page.