forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathany_array.py
More file actions
353 lines (279 loc) · 12 KB
/
any_array.py
File metadata and controls
353 lines (279 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import sys
import random
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Generic,
Iterable,
List,
MutableSequence,
Sequence,
Type,
TypeVar,
Union,
cast,
overload,
)
import numpy as np
from docarray.base_doc.doc import BaseDocWithoutId
from docarray.display.document_array_summary import DocArraySummary
from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
from docarray.utils._internal.pydantic import is_pydantic_v2
if TYPE_CHECKING:
from docarray.proto import DocListProto, NodeProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
if sys.version_info >= (3, 12):
from types import GenericAlias
# TypeVar for methods returning an instance of the same array class (fluent API).
T = TypeVar('T', bound='AnyDocArray')
# TypeVar for the document type stored in the array.
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
# Types accepted for fancy indexing into an array: a slice, an iterable of
# integer indices, an iterable of booleans (a mask), or None.
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
# Error message used when an invalidated instance is accessed; `{cls}` is
# filled in with the concrete class name at raise time.
UNUSABLE_ERROR_MSG = (
    'This {cls} instance is in an unusable state. \n'
    'The most common cause of this is converting a DocVec to a DocList. '
    'After you call `doc_vec.to_doc_list()`, `doc_vec` cannot be used anymore. '
    'Instead, you should do `doc_list = doc_vec.to_doc_list()` and only use `doc_list`.'
)
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
    """Abstract base class for array-like containers of documents.

    Concrete subclasses can be parametrized with a document type via
    ``cls[MyDoc]``; the dynamically created parametrized class exposes every
    field of the document schema as a column-style property (see
    ``__class_getitem__``).
    """

    # The document schema this (parametrized) array holds.
    doc_type: Type[BaseDocWithoutId]
    # Cache of dynamically created parametrized classes, keyed first by the
    # array class and then by the document type, so that ``cls[item]`` always
    # returns the same class object for the same pair.
    __typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}

    def __repr__(self):
        # Show the concrete class name and the number of contained documents.
        return f'<{self.__class__.__name__} (length={len(self)})>'

    @classmethod
    def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
        """Parametrize the array class with a document type.

        For a ``BaseDocWithoutId`` subclass ``item``, create (and cache) a
        subclass whose ``doc_type`` is ``item`` and which exposes each
        document field as a property delegating to
        ``_get_data_column`` / ``_set_data_column``.

        :param item: the document type, or a ``TypeVar``/``str`` for plain
            generic aliasing.
        :return: the parametrized class (or a generic alias).
        """
        if not isinstance(item, type):
            if sys.version_info < (3, 12):
                return Generic.__class_getitem__.__func__(cls, item)  # type: ignore
                # This does nothing but check that `item` is a valid TypeVar or str.
                # Keep the approach in #1147 to be compatible with lower versions of Python.
            else:
                return GenericAlias(cls, item)  # type: ignore
        if not safe_issubclass(item, BaseDocWithoutId):
            raise ValueError(
                f'{cls.__name__}[item] item should be a Document not a {item} '
            )
        if cls not in cls.__typed_da__:
            cls.__typed_da__[cls] = {}
        if item not in cls.__typed_da__[cls]:
            # Promote to global scope so multiprocessing can pickle it
            global _DocArrayTyped
            if not is_pydantic_v2:

                class _DocArrayTyped(cls):  # type: ignore
                    doc_type: Type[BaseDocWithoutId] = cast(
                        Type[BaseDocWithoutId], item
                    )

            else:
                # Under pydantic v2 the subclass must stay generic in T_doc.

                class _DocArrayTyped(cls, Generic[T_doc]):  # type: ignore
                    doc_type: Type[BaseDocWithoutId] = cast(
                        Type[BaseDocWithoutId], item
                    )

            for field in _DocArrayTyped.doc_type._docarray_fields().keys():

                def _property_generator(val: str):
                    # Factory so each property closes over its own field name
                    # (avoids the late-binding-closure pitfall with `field`).
                    def _getter(self):
                        if getattr(self, '_is_unusable', False):
                            raise UnusableObjectError(
                                UNUSABLE_ERROR_MSG.format(cls=cls.__name__)
                            )
                        return self._get_data_column(val)

                    def _setter(self, value):
                        if getattr(self, '_is_unusable', False):
                            raise UnusableObjectError(
                                UNUSABLE_ERROR_MSG.format(cls=cls.__name__)
                            )
                        self._set_data_column(val, value)

                    # need docstring for the property
                    return property(fget=_getter, fset=_setter)

                setattr(_DocArrayTyped, field, _property_generator(field))
                # this generates property on the fly based on the schema of the item

            # # The global scope and qualname need to refer to this class a unique name.
            # # Otherwise, creating another _DocArrayTyped will overwrite this one.
            if not is_pydantic_v2:
                change_cls_name(
                    _DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
                )
                cls.__typed_da__[cls][item] = _DocArrayTyped
            else:
                change_cls_name(_DocArrayTyped, f'{cls.__name__}', globals())
                if sys.version_info < (3, 12):
                    cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(
                        _DocArrayTyped, item
                    )  # type: ignore
                    # This does nothing but check that `item` is a valid TypeVar or str.
                    # Keep the approach in #1147 to be compatible with lower versions of Python.
                else:
                    cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)  # type: ignore
        return cls.__typed_da__[cls][item]

    @overload
    def __getitem__(self: T, item: int) -> T_doc:
        ...

    @overload
    def __getitem__(self: T, item: IndexIterType) -> T:
        ...

    @abstractmethod
    def __getitem__(self, item: Union[int, IndexIterType]) -> Union[T_doc, T]:
        """Return a single document (int index) or a sub-array (slice/mask)."""
        ...

    def __getattr__(self, item: str):
        # Needs to be explicitly defined here for the purpose to disable PyCharm's complaints
        # about not detected properties: https://youtrack.jetbrains.com/issue/PY-47991
        return super().__getattribute__(item)

    @abstractmethod
    def _get_data_column(
        self: T,
        field: str,
    ) -> Union[MutableSequence, T, 'AbstractTensor', None]:
        """Return all values of the fields from all docs this array contains

        :param field: name of the fields to extract
        :return: Returns a list of the field value for each document
        in the array like container
        """
        ...

    @abstractmethod
    def _set_data_column(
        self: T,
        field: str,
        values: Union[List, T, 'AbstractTensor'],
    ):
        """Set all Documents in this [`DocList`][docarray.array.doc_list.doc_list.DocList] using the passed values

        :param field: name of the fields to extract
        :values: the values to set at the DocList level
        """
        ...

    @classmethod
    @abstractmethod
    def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T:
        """create a Document from a protobuf message"""
        ...

    @abstractmethod
    def to_protobuf(self) -> 'DocListProto':
        """Convert DocList into a Protobuf message"""
        ...

    def _to_node_protobuf(self) -> 'NodeProto':
        """Convert a [`DocList`][docarray.array.doc_list.doc_list.DocList] into a NodeProto
        protobuf message.
        This function should be called when a DocList is nested into
        another Document that need to be converted into a protobuf.

        :return: the nested item protobuf message
        """
        # Local import to avoid a hard dependency on protobuf at module load.
        from docarray.proto import NodeProto

        return NodeProto(doc_array=self.to_protobuf())

    @abstractmethod
    def traverse_flat(
        self: 'AnyDocArray',
        access_path: str,
    ) -> Union[List[Any], 'AbstractTensor']:
        """
        Return a List of the accessed objects when applying the `access_path`. If this
        results in a nested list or list of [`DocList`s][docarray.array.doc_list.doc_list.DocList], the list will be flattened
        on the first level. The access path is a string that consists of attribute
        names, concatenated and `"__"`-separated. It describes the path from the first
        level to an arbitrary one, e.g. `'content__image__url'`.

        ```python
        from docarray import BaseDoc, DocList, Text


        class Author(BaseDoc):
            name: str


        class Book(BaseDoc):
            author: Author
            content: Text


        docs = DocList[Book](
            Book(author=Author(name='Jenny'), content=Text(text=f'book_{i}'))
            for i in range(10)  # noqa: E501
        )

        books = docs.traverse_flat(access_path='content')  # list of 10 Text objs

        authors = docs.traverse_flat(access_path='author__name')  # list of 10 strings
        ```

        If the resulting list is a nested list, it will be flattened:

        ```python
        from docarray import BaseDoc, DocList


        class Chapter(BaseDoc):
            content: str


        class Book(BaseDoc):
            chapters: DocList[Chapter]


        docs = DocList[Book](
            Book(chapters=DocList[Chapter]([Chapter(content='some_content') for _ in range(3)]))
            for _ in range(10)
        )

        chapters = docs.traverse_flat(access_path='chapters')  # list of 30 strings
        ```

        If your [`DocList`][docarray.array.doc_list.doc_list.DocList] is in doc_vec mode, and you want to access a field of
        type `AnyTensor`, the doc_vec tensor will be returned instead of a list:

        ```python
        class Image(BaseDoc):
            tensor: TorchTensor[3, 224, 224]


        batch = DocList[Image](
            [
                Image(
                    tensor=torch.zeros(3, 224, 224),
                )
                for _ in range(2)
            ]
        )

        batch_stacked = batch.stack()
        tensors = batch_stacked.traverse_flat(
            access_path='tensor'
        )  # tensor of shape (2, 3, 224, 224)
        ```

        :param access_path: a string that represents the access path ("__"-separated).
        :return: list of the accessed objects, flattened if nested.
        """
        ...

    @staticmethod
    def _traverse(node: Any, access_path: str):
        """Recursively yield the objects reached by following `access_path`.

        Splits off the first `"__"`-separated attribute, descends into it
        (iterating element-wise if `node` is a DocList/list), and recurses on
        the remaining path. Yields `node` itself once the path is exhausted.
        """
        if access_path:
            curr_attr, _, path_attrs = access_path.partition('__')
            # Local import to avoid a circular import at module load time.
            from docarray.array import DocList

            if isinstance(node, (DocList, list)):
                for n in node:
                    x = getattr(n, curr_attr)
                    yield from AnyDocArray._traverse(x, path_attrs)
            else:
                x = getattr(node, curr_attr)
                yield from AnyDocArray._traverse(x, path_attrs)
        else:
            yield node

    @staticmethod
    def _flatten_one_level(sequence: List[Any]) -> List[Any]:
        """Flatten `sequence` by exactly one level of nesting.

        Whether flattening is needed is decided from the FIRST element only:
        if it is not a list/DocList (or the sequence is empty), the sequence
        is returned unchanged.
        """
        from docarray import DocList

        if len(sequence) == 0 or not isinstance(sequence[0], (list, DocList)):
            return sequence
        else:
            return [item for sublist in sequence for item in sublist]

    def summary(self):
        """
        Print a summary of this [`DocList`][docarray.array.doc_list.doc_list.DocList] object and a summary of the schema of its
        Document type.
        """
        DocArraySummary(self).summary()

    def _batch(
        self: T,
        batch_size: int,
        shuffle: bool = False,
        show_progress: bool = False,
    ) -> Generator[T, None, None]:
        """
        Creates a `Generator` that yields [`DocList`][docarray.array.doc_list.doc_list.DocList] of size `batch_size`.
        Note, that the last batch might be smaller than `batch_size`.

        :param batch_size: Size of each generated batch.
        :param shuffle: If set, shuffle the Documents before dividing into minibatches.
        :param show_progress: if set, show a progress bar when batching documents.
        :yield: a Generator of [`DocList`][docarray.array.doc_list.doc_list.DocList], each in the length of `batch_size`
        """
        from rich.progress import track

        if not (isinstance(batch_size, int) and batch_size > 0):
            raise ValueError(
                f'`batch_size` should be a positive integer, received: {batch_size}'
            )
        N = len(self)
        indices = list(range(N))
        n_batches = int(np.ceil(N / batch_size))
        if shuffle:
            random.shuffle(indices)
        for i in track(
            range(n_batches),
            description='Batching documents',
            disable=not show_progress,
        ):
            # Fancy-index with a list of positions; shuffling only permutes
            # `indices`, so batches cover every document exactly once.
            yield self[indices[i * batch_size : (i + 1) * batch_size]]