-
Notifications
You must be signed in to change notification settings - Fork 238
Expand file tree
/
Copy pathsetitem.py
More file actions
249 lines (221 loc) · 8.59 KB
/
setitem.py
File metadata and controls
249 lines (221 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import itertools
from typing import (
TYPE_CHECKING,
Union,
Sequence,
overload,
Any,
List,
)
import numpy as np
from docarray import Document
from docarray.helper import typename
if TYPE_CHECKING: # pragma: no cover
from docarray.typing import (
DocumentArrayIndexType,
DocumentArraySingletonIndexType,
DocumentArrayMultipleIndexType,
DocumentArrayMultipleAttributeType,
DocumentArraySingleAttributeType,
)
class SetItemMixin:
    """Provides helper function to allow advanced indexing for `__setitem__`"""

    @overload
    def __setitem__(
        self,
        index: 'DocumentArrayMultipleAttributeType',
        value: List[List['Any']],
    ):
        ...

    @overload
    def __setitem__(
        self,
        index: 'DocumentArraySingleAttributeType',
        value: List['Any'],
    ):
        ...

    @overload
    def __setitem__(
        self,
        index: 'DocumentArraySingletonIndexType',
        value: 'Document',
    ):
        ...

    @overload
    def __setitem__(
        self,
        index: 'DocumentArrayMultipleIndexType',
        value: Sequence['Document'],
    ):
        ...

    def __setitem__(
        self,
        index: 'DocumentArrayIndexType',
        value: Union['Document', Sequence['Document']],
    ):
        """Assign ``value`` to whatever ``index`` selects.

        Supported index forms, dispatched in this order:

        - integer offset (built-in ``int`` or NumPy scalar, but not ``bool``):
          ``da[1] = Document()``
        - string starting with ``'@'`` (traversal path):
          ``da['@m,c'] = [m1, ..., c1, ...]``
        - any other string (document id): ``da['id_123'] = Document()``
        - slice: ``da[1:3] = [d1, d2]``
        - ``Ellipsis``: ``da[...] = [d1, d2, ..., dn]``
        - two-element tuple that is not a pair of booleans:
          ``da[idx1, idx2] = value`` (see :meth:`_set_by_pair`)
        - sequence of ``bool`` (mask): ``da[True, False, True, True] = [...]``
        - sequence of ``int`` or ``str``: ``da[id1, id2, id3] = [d1, d2, d3]``
        - ``np.ndarray`` that is one-dimensional after squeezing

        An empty sequence or empty array selects nothing and is a no-op.

        :param index: the selector, in one of the forms listed above
        :param value: the Document(s) or attribute value(s) to assign
        :raises IndexError: if ``index`` has an unsupported type, a sequence
            holds elements other than bool/int/str, or an ndarray index has
            more than one non-trivial dimension
        """
        if getattr(self, '_is_subindex', None):
            # Imported here because only sub-index arrays need the check.
            from docarray.helper import check_root_id

            check_root_id(self, value)

        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; this call is assumed to run for every assignment, not
        # only for sub-indices -- confirm.
        self._update_subindices_set(index, value)

        # set by offset
        # allows da[1] = Document()
        if isinstance(index, (int, np.generic)) and not isinstance(index, bool):
            self._set_doc_by_offset(int(index), value)

        elif isinstance(index, str):
            # set by traversal paths
            # allows da['@m,c] = [m1, m2, ..., mn, c1, c2, ..., cp]
            if index.startswith('@'):
                self._set_doc_value_pairs_nested(self.traverse_flat(index[1:]), value)
            # set by ID
            # allows da['id_123'] = Document()
            else:
                self._set_doc(index, value)

        # set by slice
        # allows da[1:3] = [d1, d2]
        elif isinstance(index, slice):
            self._set_docs_by_slice(index, value)

        # flatten and set
        # allows da[...] = [d1, d2,..., dn]
        elif index is Ellipsis:
            self._set_doc_value_pairs(self.flatten(), value)

        # index is sequence
        elif isinstance(index, Sequence):
            # An empty selection (e.g. a mask computed over an empty array)
            # touches nothing; return before peeking at index[0].
            if not index:
                return

            is_pair = isinstance(index, tuple) and len(index) == 2
            # bool is a subclass of int, so a two-element boolean mask would
            # otherwise be mistaken for a pair of offsets (1 and 0).
            is_bool_pair = is_pair and all(isinstance(_i, bool) for _i in index)

            # allows da[idx1, idx2] = value
            if is_pair and not is_bool_pair:
                self._set_by_pair(index[0], index[1], value)
            # allows da[True, False, True, True]
            elif isinstance(index[0], bool):
                self._set_by_mask(index, value)
            # allows da[id1, id2, id3] = [d1, d2, d3]
            elif isinstance(index[0], (int, str)):
                for si, _val in zip(index, value):
                    self[si] = _val  # leverage existing setter
            else:
                raise IndexError(
                    f'{index} should be either a sequence of bool, int or str'
                )

        # set by ndarray
        elif isinstance(index, np.ndarray):
            squeezed = index.squeeze()
            if index.ndim > 0:
                # squeeze() turns a one-element array into a 0-d array;
                # restore one dimension so np.array([2]) is a valid index.
                squeezed = np.atleast_1d(squeezed)
            index = squeezed
            if index.ndim == 1:
                self[index.tolist()] = value  # leverage existing setter
            else:
                raise IndexError(
                    f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}'
                )
        else:
            raise IndexError(f'Unsupported index type {typename(index)}: {index}')

    def _set_by_pair(self, idx1, idx2, value):
        """Handle ``da[idx1, idx2] = value``.

        The meaning of the pair depends on the type of ``idx1``:

        - document id (string not starting with ``'@'``): ``idx2`` is another
          id contained in ``self``, an attribute name, or a sequence of
          attribute names
        - integer offset (built-in ``int`` or NumPy integer): ``idx2`` is
          another offset, an attribute name, or a sequence of attribute names
        - slice, sequence, ``Ellipsis`` or traversal path: ``idx2`` is passed
          on as the attribute(s) to set on every selected Document

        :param idx1: first element of the index pair
        :param idx2: second element of the index pair
        :param value: the value(s) to assign
        :raises IndexError: if the pair does not match any supported form
        """
        if isinstance(idx1, str) and not idx1.startswith('@'):
            # second is an ID
            # allows da[id1, id2] = [d1, d2]
            if isinstance(idx2, str) and idx2 in self:
                self._set_doc_value_pairs((self[idx1], self[idx2]), value)
            # second is an attribute
            # allows da[id, attr] = attr_value
            elif isinstance(idx2, str) and hasattr(self[idx1], idx2):
                self._set_doc_attr_by_id(idx1, idx2, value)
            # second is a list of attributes:
            # allows da[id, [attr1, attr2, attr3]] = [v1, v2, v3]
            elif (
                isinstance(idx2, Sequence)
                and all(isinstance(attr, str) for attr in idx2)
                and all(hasattr(self[idx1], attr) for attr in idx2)
            ):
                for attr, _v in zip(idx2, value):
                    self._set_doc_attr_by_id(idx1, attr, _v)
            else:
                raise IndexError(f'`{idx2}` is neither a valid id nor attribute name')

        elif isinstance(idx1, (int, np.integer)):
            # Accept NumPy integers just as the single-offset setter does,
            # and hand plain ints to the offset helpers.
            if isinstance(idx1, np.integer):
                idx1 = int(idx1)
            if isinstance(idx2, np.integer):
                idx2 = int(idx2)

            # second is an offset
            # allows da[offset1, offset2] = [d1, d2]
            if isinstance(idx2, int):
                self._set_doc_value_pairs((self[idx1], self[idx2]), value)
            # second is an attribute
            # allows da[offset, attr] = value
            elif isinstance(idx2, str) and hasattr(self[idx1], idx2):
                self._set_doc_attr_by_offset(idx1, idx2, value)
            # second is a list of attributes
            # allows da[offset, [attr1, attr2, attr3]] = [v1, v2, v3]
            elif (
                isinstance(idx2, Sequence)
                and all(isinstance(attr, str) for attr in idx2)
                and all(hasattr(self[idx1], attr) for attr in idx2)
            ):
                for attr, _v in zip(idx2, value):
                    self._set_doc_attr_by_offset(idx1, attr, _v)
            else:
                raise IndexError(f'`{idx2}` must be an attribute or list of attributes')

        # allows da[sequence/slice/ellipsis/traversal_path, attributes] = [v1, v2, ...]
        elif (
            isinstance(idx1, (slice, Sequence))
            or idx1 is Ellipsis
            or (isinstance(idx1, str) and idx1.startswith('@'))
        ):
            self._set_docs_attributes(idx1, idx2, value)
        else:
            raise IndexError(f'Unsupported first index type {typename(idx1)}: {idx1}')

    def _set_by_mask(self, mask: List[bool], value):
        """Assign ``value`` to the elements of ``self`` whose mask entry is truthy.

        :param mask: selectors paired positionally with the elements of ``self``
        :param value: the replacement values for the selected elements
        """
        _selected = itertools.compress(self, mask)
        self._set_doc_value_pairs(_selected, value)

    def _set_docs_attributes(self, index, attributes, value):
        """Set one or more attributes on every Document selected by ``index``.

        :param index: a traversal path (string starting with ``'@'``),
            ``Ellipsis``, or any other index accepted by ``self[index]``
        :param attributes: a single attribute name or a sequence of names
        :param value: the value for a single attribute name, or one value per
            attribute name. A ``list``/``tuple`` value is distributed over the
            selected Documents one by one; any other value is given to all.
        """
        if isinstance(attributes, str):
            # a -> [a]
            # [a, a] -> [a, a]
            attributes = (attributes,)
            value = (value,)

        if isinstance(index, str) and index.startswith('@'):
            self._set_docs_attributes_traversal_paths(index, attributes, value)
        elif index is Ellipsis:
            _docs = self[index]
            for _a, _v in zip(attributes, value):
                if _a == 'tensor':
                    _docs.tensors = _v
                elif _a == 'embedding':
                    _docs.embeddings = _v
                else:
                    if not isinstance(_v, (list, tuple)):
                        for _d in _docs:
                            setattr(_d, _a, _v)
                    else:
                        for _d, _vv in zip(_docs, _v):
                            setattr(_d, _a, _vv)
            # NOTE(review): assumed to run once after all attributes are set
            # (nesting reconstructed from de-indented source) -- confirm.
            self._set_doc_value_pairs_nested(_docs, _docs)
        else:
            _docs = self[index]
            if not _docs:
                return

            for _a, _v in zip(attributes, value):
                if _a in ('tensor', 'embedding'):
                    if _a == 'tensor':
                        _docs.tensors = _v
                    elif _a == 'embedding':
                        _docs.embeddings = _v
                    # store each updated Document back under its own id
                    for _d in _docs:
                        self._set_doc(_d.id, _d)
                else:
                    if not isinstance(_v, (list, tuple)):
                        for _d in _docs:
                            self._set_doc_attr_by_id(_d.id, _a, _v)
                    else:
                        for _d, _vv in zip(_docs, _v):
                            self._set_doc_attr_by_id(_d.id, _a, _vv)

    def _set_docs_attributes_traversal_paths(
        self, traversal_paths: str, attributes, value
    ):
        """Set attributes on the Documents selected by a traversal path.

        :param traversal_paths: the index string, passed to ``self[...]`` as is
        :param attributes: sequence of attribute names
        :param value: one value per attribute name. A ``list``/``tuple`` value
            is distributed over the selected Documents one by one; any other
            value is given to all of them.
        """
        _docs = self[traversal_paths]
        if not _docs:
            return

        for _a, _v in zip(attributes, value):
            if _a == 'tensor':
                _docs.tensors = _v
            elif _a == 'embedding':
                _docs.embeddings = _v
            else:
                if not isinstance(_v, (list, tuple)):
                    for _d in _docs:
                        setattr(_d, _a, _v)
                else:
                    for _d, _vv in zip(_docs, _v):
                        setattr(_d, _a, _vv)
        # NOTE(review): assumed to run once after all attributes are set
        # (nesting reconstructed from de-indented source) -- confirm.
        self._set_doc_value_pairs_nested(_docs, _docs)