-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy patharray.py
More file actions
303 lines (244 loc) · 9.24 KB
/
Copy patharray.py
File metadata and controls
303 lines (244 loc) · 9.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import io
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
MutableSequence,
Optional,
Sequence,
Type,
TypeVar,
Union,
overload,
)
from typing_inspect import is_union_type
from docarray.array.abstract_array import AnyDocArray
from docarray.array.array.io import IOMixinArray
from docarray.array.array.pushpull import PushPullMixin
from docarray.array.array.sequence_indexing_mixin import (
IndexingSequenceMixin,
IndexIterType,
)
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray
if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField
from docarray.array.stacked.array_stacked import DocArrayStacked
from docarray.proto import DocumentArrayProto
from docarray.typing import TorchTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
# Type variable for a DocArray (sub)class; lets classmethods such as
# `construct` and `from_protobuf` return the caller's own DocArray type.
T = TypeVar('T', bound='DocArray')
# Type variable for the Document type stored in a DocArray.
T_doc = TypeVar('T_doc', bound=BaseDoc)
def _delegate_meth_to_data(meth_name: str) -> Callable:
"""
create a function that mimic a function call to the data attribute of the
DocArray
:param meth_name: name of the method
:return: a method that mimic the meth_name
"""
func = getattr(list, meth_name)
@wraps(func)
def _delegate_meth(self, *args, **kwargs):
return getattr(self._data, meth_name)(*args, **kwargs)
return _delegate_meth
class DocArray(
    IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
):
    """
    DocArray is a container of Documents.

    A DocArray is a list of Documents of any schema. However, many
    DocArray features are only available if these Documents are
    homogeneous and follow the same schema. To specify this schema you can use
    the `DocArray[MyDocument]` syntax where MyDocument is a Document class
    (i.e. schema). This creates a DocArray that can only contain Documents of
    the type 'MyDocument'.

    ---

    ```python
    from docarray import BaseDoc, DocArray
    from docarray.typing import NdArray, ImageUrl
    from typing import Optional


    class Image(BaseDoc):
        tensor: Optional[NdArray[100]]
        url: ImageUrl


    da = DocArray[Image](
        Image(url='http://url.com/foo.png') for _ in range(10)
    )  # noqa: E510
    ```

    ---

    If your DocArray is homogeneous (i.e. follows the same schema), you can access
    fields at the DocArray level (for example `da.tensor` or `da.url`).
    You can also set fields, with `da.tensor = np.random.random([10, 100])`:

    ```python
    print(da.url)
    # [ImageUrl('http://url.com/foo.png', host_type='domain'), ...]
    import numpy as np

    da.tensor = np.random.random([10, 100])
    print(da.tensor)
    # [NdArray([0.11299577, 0.47206767, 0.481723  , 0.34754724, 0.15016037,
    #          0.88861321, 0.88317666, 0.93845579, 0.60486676, ... ]), ...]
    ```

    You can index into a DocArray like a numpy array or torch tensor:

    ```python
    da[0]  # index by position
    da[0:5:2]  # index by slice
    da[[0, 2, 3]]  # index by list of indices
    da[[True, False, True, True, ...]]  # index by boolean mask
    ```

    You can delete items from a DocArray like a Python List:

    ```python
    del da[0]  # remove the first element from the DocArray
    del da[0:5]  # remove elements 0 to 4 from the DocArray
    ```

    :param docs: iterable of Document
    """

    # Schema of the Documents this DocArray accepts; AnyDoc (the default)
    # disables the per-document type check in `_validate_one_doc`.
    document_type: Type[BaseDoc] = AnyDoc

    def __init__(
        self,
        docs: Optional[Iterable[T_doc]] = None,
    ):
        # Compare against None explicitly: `if docs` raises for inputs whose
        # truth value is ambiguous (e.g. a numpy object array of Documents)
        # and says nothing useful for generators, which are always truthy.
        self._data: List[T_doc] = (
            list(self._validate_docs(docs)) if docs is not None else []
        )

    @classmethod
    def construct(
        cls: Type[T],
        docs: Sequence[T_doc],
    ) -> T:
        """
        Create a DocArray without validating any data. The data must come from a
        trusted source.

        :param docs: a Sequence (list) of Document with the same schema
        :return: a new DocArray holding `docs`
        """
        # Bypass __init__ so no Document goes through `_validate_one_doc`.
        da = cls.__new__(cls)
        # A list is stored as-is (not copied), so later mutations of `docs`
        # by the caller are visible through the returned DocArray.
        da._data = docs if isinstance(docs, list) else list(docs)
        return da

    def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
        """
        Validate if an Iterable of Document are compatible with this DocArray

        Lazily yields each Document after it passes `_validate_one_doc`.
        """
        for doc in docs:
            yield self._validate_one_doc(doc)

    def _validate_one_doc(self, doc: T_doc) -> T_doc:
        """Validate if a Document is compatible with this DocArray

        :param doc: the Document to check
        :return: `doc` itself, unchanged, if it is accepted
        :raises ValueError: if `document_type` is not AnyDoc (or a subclass of
            it) and `doc` is not an instance of `document_type`
        """
        if not issubclass(self.document_type, AnyDoc) and not isinstance(
            doc, self.document_type
        ):
            raise ValueError(f'{doc} is not a {self.document_type}')
        return doc

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __bytes__(self) -> bytes:
        # Serialization itself is done by `_write_bytes`, which is not defined
        # in this class (presumably provided by IOMixinArray).
        with io.BytesIO() as bf:
            self._write_bytes(bf=bf)
            return bf.getvalue()

    def append(self, doc: T_doc):
        """
        Append a Document to the DocArray. The Document must be from the same class
        as the document_type of this DocArray otherwise it will fail.

        :param doc: A Document
        :raises ValueError: if `doc` does not match `document_type`
        """
        self._data.append(self._validate_one_doc(doc))

    def extend(self, docs: Iterable[T_doc]):
        """
        Extend a DocArray with an Iterable of Document. The Documents must be from
        the same class as the document_type of this DocArray otherwise it will
        fail.

        :param docs: Iterable of Documents
        :raises ValueError: if any Document does not match `document_type`
        """
        self._data.extend(self._validate_docs(docs))

    def insert(self, i: int, doc: T_doc):
        """
        Insert a Document to the DocArray. The Document must be from the same
        class as the document_type of this DocArray otherwise it will fail.

        :param i: index to insert
        :param doc: A Document
        :raises ValueError: if `doc` does not match `document_type`
        """
        self._data.insert(i, self._validate_one_doc(doc))

    # These list methods never add new Documents, so they are forwarded
    # directly to the underlying list without validation.
    pop = _delegate_meth_to_data('pop')
    remove = _delegate_meth_to_data('remove')
    reverse = _delegate_meth_to_data('reverse')
    sort = _delegate_meth_to_data('sort')

    def _get_data_column(
        self: T,
        field: str,
    ) -> Union[MutableSequence, T, 'TorchTensor', 'NdArray']:
        """Return all values of the fields from all docs this array contains

        :param field: name of the fields to extract
        :return: a DocArray of the field values if the field's type is a
            (non-Union) BaseDoc subclass, otherwise a plain list with one value
            per document
        """
        field_type = self.__class__.document_type._get_field_type(field)

        if (
            not is_union_type(field_type)
            and isinstance(field_type, type)
            and issubclass(field_type, BaseDoc)
        ):
            # calling __class_getitem__ ourselves is a hack otherwise mypy complain
            # most likely a bug in mypy though
            # bug reported here https://github.com/python/mypy/issues/14111
            return DocArray.__class_getitem__(field_type)(
                (getattr(doc, field) for doc in self),
            )
        else:
            return [getattr(doc, field) for doc in self]

    def _set_data_column(
        self: T,
        field: str,
        values: Union[List, T, 'AbstractTensor'],
    ):
        """Set all Documents in this DocArray using the passed values

        :param field: name of the fields to set
        :param values: the values to set at the DocArray level, one per Document
        """
        # NOTE(review): zip() silently stops at the shorter of the two, so a
        # length mismatch between `values` and this DocArray is not reported.
        for doc, value in zip(self, values):
            setattr(doc, field, value)

    def stack(
        self,
        tensor_type: Type['AbstractTensor'] = NdArray,
    ) -> 'DocArrayStacked':
        """
        Convert the DocArray into a DocArrayStacked. `Self` cannot be used
        afterwards

        :param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful
            if the BaseDoc has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor
        :return: A DocArrayStacked of the same document type as self
        """
        from docarray.array.stacked.array_stacked import DocArrayStacked

        return DocArrayStacked.__class_getitem__(self.document_type)(
            self, tensor_type=tensor_type
        )

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[T, Iterable[BaseDoc]],
        field: 'ModelField',
        config: 'BaseConfig',
    ):
        """
        Coerce `value` into a DocArray of this class (presumably invoked by
        pydantic when a DocArray is used as a field type; `field` and `config`
        are not used here).

        :param value: an instance of this class or of DocArrayStacked (returned
            as-is), or any Iterable of Documents (wrapped, with validation)
        :return: a DocArray / DocArrayStacked
        :raises TypeError: if `value` is not an Iterable
        :raises ValueError: if an element does not match `document_type`
        """
        from docarray.array.stacked.array_stacked import DocArrayStacked

        if isinstance(value, (cls, DocArrayStacked)):
            return value
        elif isinstance(value, Iterable):
            return cls(value)
        else:
            raise TypeError(f'Expecting an Iterable of {cls.document_type}')

    def traverse_flat(
        self: 'DocArray',
        access_path: str,
    ) -> List[Any]:
        """
        Collect the nodes reached by `access_path` (via `AnyDocArray._traverse`)
        and flatten the result by one level.

        :param access_path: path of the attribute(s) to collect from the Documents
        :return: a flat list of the collected values
        """
        nodes = list(AnyDocArray._traverse(node=self, access_path=access_path))
        flattened = AnyDocArray._flatten_one_level(nodes)

        return flattened

    @classmethod
    def from_protobuf(cls: Type[T], pb_msg: 'DocumentArrayProto') -> T:
        """create a DocArray from a protobuf message

        :param pb_msg: The protobuf message from where to construct the DocArray
        :return: a DocArray of this class
        """
        return super().from_protobuf(pb_msg)

    @overload
    def __getitem__(self, item: int) -> T_doc:
        ...

    @overload
    def __getitem__(self: T, item: IndexIterType) -> T:
        ...

    def __getitem__(self, item):
        return super().__getitem__(item)