-
Notifications
You must be signed in to change notification settings - Fork 238
Expand file tree
/
Copy pathlist_advance_indexing.py
More file actions
223 lines (179 loc) · 6.64 KB
/
list_advance_indexing.py
File metadata and controls
223 lines (179 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
from typing import (
Any,
Iterable,
List,
Sequence,
TypeVar,
Union,
cast,
no_type_check,
overload,
)
import numpy as np
from typing_extensions import SupportsIndex
from docarray.utils._internal.misc import (
is_jax_available,
is_tf_available,
is_torch_available,
)
# Optional tensor backends. Each framework is imported only when the matching
# is_*_available() helper reports it; the *_available flags record that result
# so ListAdvancedIndexing._normalize_index_item only checks for tensor types of
# frameworks that were actually imported.
torch_available = is_torch_available()
if torch_available:
    import torch

tf_available = is_tf_available()
if tf_available:
    import tensorflow as tf  # type: ignore
    from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor

jax_available = is_jax_available()
if jax_available:
    import jax.numpy as jnp
    from docarray.typing.tensor.jaxarray import JaxArray

# Type of the elements stored in a ListAdvancedIndexing.
T_item = TypeVar('T_item')
# "Self" type: bulk-indexing methods build results with self.__class__(...), so
# they return the concrete subclass they were called on.
T = TypeVar('T', bound='ListAdvancedIndexing')
# Non-scalar index types accepted by __getitem__/__setitem__/__delitem__.
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
def _is_np_int(item: Any) -> bool:
dtype = getattr(item, 'dtype', None)
ndim = getattr(item, 'ndim', None)
if dtype is not None and ndim is not None:
try:
return ndim == 0 and np.issubdtype(dtype, np.integer)
except TypeError:
return False
return False # this is unreachable, but mypy wants it
class ListAdvancedIndexing(List[T_item]):
    """
    A list wrapper that implements custom indexing

    You can index into a ListAdvancedIndexing like a numpy array or torch tensor:

    ---

    ```python
    docs[0]  # index by position
    docs[0:5:2]  # index by slice
    docs[[0, 2, 3]]  # index by list of indices
    docs[[True, False, True, True]]  # index by boolean mask
    ```

    ---

    Bulk indices (lists, tuples, integer/bool numpy arrays, integer/bool torch
    tensors, TensorFlow and JAX tensors) are first normalized to plain Python
    sequences. Whether a bulk index is a boolean mask or a list of positions
    is decided by the type of its first element. An empty bulk index selects
    nothing.
    """

    @staticmethod
    def _normalize_index_item(
        item: Any,
    ) -> Union[int, slice, Iterable[int], Iterable[bool], None]:
        """
        Convert an index of any supported type into a plain Python index.

        :param item: the key passed to `__getitem__`, `__setitem__` or
            `__delitem__`.
        :return: `None`, an `int`, a `slice`, or a subscriptable bulk index
            (a list, a tuple, or any other object with `__getitem__` that is
            `Iterable`).
        :raises ValueError: if `item` is not a basic index type and is not a
            subscriptable iterable.
        """
        # basic index types
        if item is None or isinstance(item, (int, slice, tuple, list)):
            return item

        # numpy index types: 0-d integer scalars/arrays become a plain int
        if _is_np_int(item):
            return item.item()

        index_has_getitem = hasattr(item, '__getitem__')
        is_valid_bulk_index = index_has_getitem and isinstance(item, Iterable)
        if not is_valid_bulk_index:
            raise ValueError(f'Invalid index type {type(item)}')

        if isinstance(item, np.ndarray) and (
            item.dtype == np.bool_ or np.issubdtype(item.dtype, np.integer)
        ):
            return item.tolist()

        # torch index types: boolean masks and signed integer index tensors.
        # uint8 is left out on purpose: older torch code used uint8 tensors as
        # masks, so treating them as positions would be ambiguous.
        if torch_available:
            allowed_torch_dtypes = (
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            )
            if isinstance(item, torch.Tensor) and (item.dtype in allowed_torch_dtypes):
                return item.tolist()

        if tf_available:
            if isinstance(item, tf.Tensor):
                return item.numpy().tolist()
            if isinstance(item, TensorFlowTensor):
                return item.tensor.numpy().tolist()

        if jax_available:
            if isinstance(item, jnp.ndarray):
                return item.__array__().tolist()
            if isinstance(item, JaxArray):
                return item.tensor.__array__().tolist()

        # any other subscriptable iterable is used as-is
        return item

    def _get_from_indices(self: T, item: Iterable[int]) -> T:
        """
        Return a new instance holding the elements at the given positions.

        :param item: positions to select; each one is looked up via `self[ix]`.
        :return: a new instance of the same class, in the order of `item`.
        """
        return self.__class__([self[ix] for ix in item])

    def _set_by_indices(self: T, item: Iterable[int], value: Iterable[T_item]):
        """
        Assign `value` element-wise to the given positions.

        Positions and values are paired with `zip`, so the shorter of the two
        determines how many assignments happen.

        :param item: positions to assign to.
        :param value: values to assign, in the same order as `item`.
        :raises IndexError: if a position is out of range.
        """
        for ix, doc_to_set in zip(item, value):
            try:
                self[ix] = doc_to_set
            except KeyError:
                # surface a KeyError raised by an element assignment as an
                # IndexError, as callers of list-style indexing expect
                raise IndexError(f'Index {ix} is out of range')

    def _get_from_mask(self: T, item: Iterable[bool]) -> T:
        """
        Return a new instance holding the elements whose mask value is truthy.

        Elements and mask are paired with `zip`, so extra entries on either
        side are ignored.
        """
        return self.__class__(
            [doc for doc, mask_value in zip(self, item) if mask_value]
        )

    def _set_by_mask(self: T, item: Iterable[bool], value: Sequence[T_item]):
        """
        Assign consecutive entries of `value` to the positions selected by the
        mask.

        :param item: boolean mask; entries beyond `len(self)` are ignored.
        :param value: replacement values, consumed in order for each truthy
            mask entry.
        """
        i_value = 0
        for i, mask_value in zip(range(len(self)), item):
            if mask_value:
                self[i] = value[i_value]
                i_value += 1

    def _del_from_mask(self: T, item: Iterable[bool]) -> None:
        """Delete every element whose mask value is truthy."""
        idx_to_delete = [i for i, val in enumerate(item) if val]
        self._del_from_indices(idx_to_delete)

    def _del_from_indices(self: T, item: Iterable[int]) -> None:
        """
        Delete the elements at the given positions.

        Negative positions count from the end, as for a plain list. Every
        position refers to the list *before* any deletion, and a position
        that appears more than once is deleted only once.

        :param item: positions to delete.
        :raises IndexError: if any position is out of range; in that case
            nothing is deleted.
        """
        length = len(self)
        positions = set()
        for ix in item:
            # Resolve negative positions against the original length up front;
            # resolving them after earlier deletions would hit wrong elements.
            pos = ix + length if ix < 0 else ix
            if not 0 <= pos < length:
                raise IndexError(f'Index {ix} is out of range')
            positions.add(pos)
        # Delete from the back so the remaining positions stay valid.
        for pos in sorted(positions, reverse=True):
            del self[pos]

    def __delitem__(self, key: Union[SupportsIndex, IndexIterType]) -> None:
        """
        Delete by position, slice, list of positions or boolean mask.

        A `None` key and an empty bulk index delete nothing.

        :raises TypeError: if the bulk index holds neither bools nor ints.
        """
        item = self._normalize_index_item(key)
        if item is None:
            return
        if isinstance(item, (int, slice)):
            super().__delitem__(item)
            return

        # _normalize_index_item() guarantees the bulk index is subscriptable
        try:
            head = item[0]  # type: ignore
        except IndexError:
            # empty bulk index: nothing to delete
            return
        if isinstance(head, bool):
            self._del_from_mask(cast(Iterable[bool], item))
        elif isinstance(head, int):
            self._del_from_indices(item)  # type: ignore
        else:
            raise TypeError(f'Invalid type {type(head)} for indexing')

    @overload
    def __getitem__(self: T, item: SupportsIndex) -> T_item:
        ...

    @overload
    def __getitem__(self: T, item: IndexIterType) -> T:
        ...

    def __getitem__(self, item):
        """
        Index by position, slice, list of positions or boolean mask.

        A single position returns the element itself. Every other index
        returns a new instance of the same class, except `None`, which returns
        `self`. An empty bulk index returns an empty instance.

        :raises TypeError: if the bulk index holds neither bools nor ints.
        """
        item = self._normalize_index_item(item)

        if isinstance(item, slice):
            return self.__class__(super().__getitem__(item))
        if isinstance(item, int):
            return super().__getitem__(item)
        if item is None:
            return self

        # _normalize_index_item() guarantees the bulk index is subscriptable
        try:
            head = item[0]  # type: ignore
        except IndexError:
            # empty bulk index selects nothing
            return self.__class__([])
        if isinstance(head, bool):
            return self._get_from_mask(item)
        elif isinstance(head, int):
            return self._get_from_indices(item)
        else:
            raise TypeError(f'Invalid type {type(head)} for indexing')

    @overload
    def __setitem__(self: T, key: SupportsIndex, value: T_item) -> None:
        ...

    @overload
    def __setitem__(self: T, key: IndexIterType, value: Iterable[T_item]):
        ...

    @no_type_check
    def __setitem__(self: T, key, value):
        """
        Assign by position, slice, list of positions or boolean mask.

        An empty bulk index assigns nothing.

        :raises TypeError: if the bulk index holds neither bools nor ints.
        """
        key_norm = self._normalize_index_item(key)
        if isinstance(key_norm, (int, slice)):
            super().__setitem__(key_norm, value)
            return

        # _normalize_index_item() guarantees the bulk index is subscriptable
        try:
            head = key_norm[0]
        except IndexError:
            # empty bulk index: nothing to assign
            return
        if isinstance(head, bool):
            self._set_by_mask(key_norm, value)
        elif isinstance(head, int):
            self._set_by_indices(key_norm, value)
        else:
            raise TypeError(f'Invalid type {type(head)} for indexing')