-
Notifications
You must be signed in to change notification settings - Fork 238
Expand file tree
/
Copy path test_base_document.py
More file actions
189 lines (130 loc) · 4.89 KB
/
test_base_document.py
File metadata and controls
189 lines (130 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from typing import Any, List, Optional, Tuple
import numpy as np
import orjson
import pytest
from docarray import DocList, DocVec
from docarray.base_doc.doc import BaseDoc
from docarray.base_doc.io.json import orjson_dumps_and_decode
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.pydantic import is_pydantic_v2
def test_base_document_init():
    """A ``BaseDoc`` built with no arguments still ends up with an ``id``."""
    document = BaseDoc()
    assert document.id is not None
def test_update():
    """``update`` merges another document's fields into the receiver.

    Expectations pinned here:
    - a scalar set on the source overwrites the target (``content``);
    - a field left at its ``None`` default on the source keeps the
      target's value (``title``);
    - list fields are concatenated, target items first (``tags_``).
    """

    class MyDocument(BaseDoc):
        content: str
        title: Optional[str] = None
        tags_: List

    target = MyDocument(
        content='Core content of the document',
        title='Title',
        tags_=['python', 'AI'],
    )
    source = MyDocument(content='Core content updated', tags_=['docarray'])

    target.update(source)

    assert target.content == 'Core content updated'
    assert target.title == 'Title'
    assert target.tags_ == ['python', 'AI', 'docarray']
def test_equal_nested_docs():
    """Self-equality holds for a document nesting a DocList of tensor docs.

    NOTE(review): presumably this guards against ``==`` on the nested
    ``NdArray`` fields raising an ambiguous-truth-value error; confirm
    against ``BaseDoc.__eq__``.
    """
    # ``np``, ``BaseDoc``, ``DocList`` and ``NdArray`` come from the
    # module-level imports; re-importing them locally was redundant.

    class SimpleDoc(BaseDoc):
        simple_tens: NdArray[10]

    class NestedDoc(BaseDoc):
        docs: DocList[SimpleDoc]

    # Named ``nested`` rather than ``nested_docs`` so the local does not
    # shadow the module-level fixture of that name.
    nested = NestedDoc(
        docs=DocList[SimpleDoc](
            [SimpleDoc(simple_tens=np.ones(10)) for _ in range(2)]
        ),
    )

    assert nested == nested
@pytest.fixture
def nested_docs():
    """Provide a ``NestedDoc`` with two ``ones(10)`` children and ``hello='world'``."""

    class SimpleDoc(BaseDoc):
        simple_tens: NdArray[10]

    class NestedDoc(BaseDoc):
        docs: DocList[SimpleDoc]
        hello: str = 'world'

    children = [SimpleDoc(simple_tens=np.ones(10)) for _ in range(2)]
    return NestedDoc(docs=DocList[SimpleDoc](children))
@pytest.fixture
def nested_docs_docvec():
    """Provide a ``NestedDoc`` whose ``docs`` field is declared as a ``DocVec``.

    NOTE(review): the value handed in is a ``DocList``, so this presumably
    exercises DocList -> DocVec coercion during validation; confirm.
    """

    class SimpleDoc(BaseDoc):
        simple_tens: NdArray[10]

    class NestedDoc(BaseDoc):
        docs: DocVec[SimpleDoc]
        hello: str = 'world'

    children = [SimpleDoc(simple_tens=np.ones(10)) for _ in range(2)]
    return NestedDoc(docs=DocList[SimpleDoc](children))
def test_nested_to_dict(nested_docs):
    """``dict()`` renders a nested DocList as a plain ``list``, not a ``DocList``."""
    as_dict = nested_docs.dict()
    docs_value = as_dict['docs']

    assert (docs_value[0]['simple_tens'] == np.ones(10)).all()
    assert isinstance(docs_value, list)
    assert not isinstance(docs_value, DocList)
def test_nested_docvec_to_dict(nested_docs_docvec):
    """``dict()`` on a DocVec-typed field exposes each element's tensor by key."""
    first_child = nested_docs_docvec.dict()['docs'][0]
    assert (first_child['simple_tens'] == np.ones(10)).all()
def test_nested_to_dict_exclude(nested_docs):
    """Excluding the nested ``docs`` field drops it from the ``dict()`` output."""
    result = nested_docs.dict(exclude={'docs'})
    assert 'docs' not in result
def test_nested_to_dict_exclude_set(nested_docs):
    """A ``set`` passed as ``exclude`` removes the named scalar field."""
    result = nested_docs.dict(exclude={'hello'})
    assert 'hello' not in result
def test_nested_to_dict_exclude_dict(nested_docs):
    """A ``{field: True}`` mapping passed as ``exclude`` removes that field."""
    result = nested_docs.dict(exclude={'hello': True})
    assert 'hello' not in result
def test_nested_to_json(nested_docs):
    """A nested document survives a ``json()`` -> ``parse_raw`` round trip.

    Checks not only that parsing succeeds but that the id, the scalar
    field and every nested tensor come back with the same values.
    """
    serialized = nested_docs.json()
    restored = nested_docs.__class__.parse_raw(serialized)

    assert restored.id == nested_docs.id
    assert restored.hello == nested_docs.hello
    assert len(restored.docs) == len(nested_docs.docs)
    # The fixture fills every child with ``np.ones(10)``.
    for child in restored.docs:
        assert (child.simple_tens == np.ones(10)).all()
@pytest.fixture
def nested_none_docs():
    """Provide a ``NestedDoc`` whose optional ``docs`` field is left as ``None``."""

    class SimpleDoc(BaseDoc):
        simple_tens: NdArray[10]

    class NestedDoc(BaseDoc):
        docs: Optional[DocList[SimpleDoc]] = None
        hello: str = 'world'

    return NestedDoc()
def test_nested_none_to_dict(nested_none_docs):
    """``dict()`` keeps a ``None`` nested field next to the defaults and the id."""
    as_dict = nested_none_docs.dict()
    expected = {'docs': None, 'hello': 'world', 'id': nested_none_docs.id}
    assert as_dict == expected
def test_nested_none_to_json(nested_none_docs):
    """A ``None`` nested field survives a ``json()`` -> ``parse_raw`` round trip."""
    doc_cls = nested_none_docs.__class__
    restored = doc_cls.parse_raw(nested_none_docs.json())
    expected = {'docs': None, 'hello': 'world', 'id': nested_none_docs.id}
    assert restored.dict() == expected
def test_get_get_field_inner_type():
    """A bare ``Tuple`` annotation reports ``Any`` as its inner field type."""
    # NOTE(review): the doubled "get" in the test name looks like a typo;
    # it is kept so the pytest node id does not change.

    class MyDoc(BaseDoc):
        tuple_: Tuple

    assert MyDoc._get_field_inner_type("tuple_") == Any
@pytest.mark.skipif(
    is_pydantic_v2, reason="syntax only working with pydantic v1 for now"
)
def test_subclass_config():
    """A pydantic-v1 ``Config`` subclassing ``BaseDoc.Config`` keeps base options.

    The subclass adds ``arbitrary_types_allowed``. The inherited options must
    still be visible on ``MyDoc.Config``: orjson loader/dumper, tensor
    encoder, assignment validation and the protobuf extra-fields flag.
    """

    class MyDoc(BaseDoc):
        x: str

        class Config(BaseDoc.Config):
            arbitrary_types_allowed = True  # just an example setting

    cfg = MyDoc.Config
    assert cfg.json_loads == orjson.loads
    assert cfg.json_dumps == orjson_dumps_and_decode
    # Dirty identity check: the tensor encoder should hand its input back.
    assert cfg.json_encoders[AbstractTensor](3) == 3
    assert cfg.validate_assignment
    assert not cfg._load_extra_fields_from_protobuf
    assert cfg.arbitrary_types_allowed
@pytest.mark.skipif(not is_pydantic_v2, reason="syntax only working with pydantic v2")
def test_subclass_config_v2():
    """A ``model_config`` built from ``BaseDoc.ConfigDocArray`` keeps base options.

    The subclass only sets ``arbitrary_types_allowed``. The test expects the
    inherited tensor encoder, assignment validation and protobuf
    extra-fields flag to still appear in ``MyDoc.model_config``.
    """

    class MyDoc(BaseDoc):
        x: str

        model_config = BaseDoc.ConfigDocArray(
            arbitrary_types_allowed=True
        )  # just an example setting

    cfg = MyDoc.model_config
    # Dirty identity check: the tensor encoder should hand its input back.
    assert cfg['json_encoders'][AbstractTensor](3) == 3
    assert cfg['validate_assignment']
    assert not cfg['_load_extra_fields_from_protobuf']
    assert cfg['arbitrary_types_allowed']