-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathcolumn_storage.py
More file actions
110 lines (87 loc) · 3.59 KB
/
Copy pathcolumn_storage.py
File metadata and controls
110 lines (87 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from collections import ChainMap
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
MutableMapping,
Type,
TypeVar,
Union,
)
from docarray.array.stacked.list_advance_indexing import ListAdvancedIndexing
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
if TYPE_CHECKING:
from docarray.array.stacked.array_stacked import DocArrayStacked
# Row selectors accepted by ``ColumnStorage.__getitem__``: a slice, an
# iterable of integer positions, an iterable of booleans (presumably a row
# mask -- confirm against ListAdvancedIndexing), or None. The selector is
# forwarded unchanged to every column.
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
# Self-type used so that ``ColumnStorage.__getitem__`` is typed as returning
# the same (sub)class it was called on.
T = TypeVar('T', bound='ColumnStorage')
class ColumnStorage:
    """
    Column-wise container backing a
    :class:`~docarray.array.stacked.DocArrayStacked`.

    The columns are split into four dictionaries keyed by field name.
    :attr:`columns` is a ``ChainMap`` over all of them, searched in the order
    tensor, doc, da, any, so a field can be looked up by name alone.

    :param tensor_columns: a Dict of AbstractTensor
    :param doc_columns: a Dict of :class:`~docarray.array.stacked.DocArrayStacked`
    :param da_columns: a Dict of ``ListAdvancedIndexing`` of
        :class:`~docarray.array.stacked.DocArrayStacked`
    :param any_columns: a Dict of ``ListAdvancedIndexing`` holding arbitrary values
    :param tensor_type: Class used to wrap the stacked tensors
    """

    def __init__(
        self,
        tensor_columns: Dict[str, AbstractTensor],
        doc_columns: Dict[str, 'DocArrayStacked'],
        da_columns: Dict[str, ListAdvancedIndexing['DocArrayStacked']],
        any_columns: Dict[str, ListAdvancedIndexing],
        tensor_type: Type[AbstractTensor] = NdArray,
    ):
        self.tensor_columns = tensor_columns
        self.doc_columns = doc_columns
        self.da_columns = da_columns
        self.any_columns = any_columns
        self.tensor_type = tensor_type

        # Unified, read-through view of every column. The first mapping that
        # contains a key wins, and writes to the ChainMap itself would land
        # in ``tensor_columns`` (ChainMap only mutates its first map).
        self.columns = ChainMap(  # type: ignore
            tensor_columns, doc_columns, da_columns, any_columns  # type: ignore
        )

    def __len__(self) -> int:
        """Return the number of rows, measured on the ``'id'`` column."""
        # TODO what if ID are None ?
        id_column = self.any_columns['id']
        return len(id_column)

    def __getitem__(self: T, item: IndexIterType) -> T:
        """
        Build a new storage of the same class restricted to the selected rows.

        :param item: row selector forwarded to every column; a tuple is
            converted to a list first
        :return: a new ``self.__class__`` instance with the same ``tensor_type``
        """
        selector = list(item) if isinstance(item, tuple) else item

        def select(group):
            # Apply the same row selector to every column of one group.
            return {name: column[selector] for name, column in group.items()}

        return self.__class__(
            select(self.tensor_columns),
            select(self.doc_columns),
            select(self.da_columns),
            select(self.any_columns),
            self.tensor_type,
        )
class ColumnStorageView(dict, MutableMapping[str, Any]):
    """
    Mutable mapping view of a single row of a :class:`ColumnStorage`.

    Item reads and writes are forwarded to ``storage.columns[name][index]``.
    The inherited ``dict`` part is initialised empty and never written to by
    this class.

    NOTE(review): ``dict`` precedes ``MutableMapping`` in the MRO, so methods
    not overridden here (``keys``, ``items``, ``values``, ``get``,
    ``__contains__``) resolve to ``dict``'s implementations and see the empty
    underlying dict rather than the storage columns. Confirm this is intended.

    :param index: position of the row inside every column
    :param storage: the column storage this view reads from and writes to
    """

    index: int
    storage: 'ColumnStorage'

    def __init__(self, index: int, storage: 'ColumnStorage'):
        super().__init__()
        self.index = index
        self.storage = storage

    def __getitem__(self, name: str) -> Any:
        """
        Return the value of field ``name`` for this row.

        For a tensor column whose backend reports ``n_dim == 1``, a length-1
        slice of the column is returned instead of a single element.
        """
        tensor_columns = self.storage.tensor_columns
        if name in tensor_columns:
            tensor = tensor_columns[name]
            if tensor.get_comp_backend().n_dim(tensor) == 1:
                # To ensure consistency between numpy and pytorch we wrap the
                # scalar in a tensor of ndim = 1; otherwise numpy would pass
                # it by value whereas torch passes it by reference.
                return tensor[self.index : self.index + 1]
        return self.storage.columns[name][self.index]

    def __setitem__(self, name, value) -> None:
        """Write ``value`` into row ``index`` of column ``name``."""
        self.storage.columns[name][self.index] = value

    def __delitem__(self, key):
        """Deleting fields from a row view is not supported."""
        raise RuntimeError('Cannot delete an item from a StorageView')

    def __iter__(self):
        """
        Return an iterator over the column names of the storage.

        ``__iter__`` must return an iterator. Returning the bare ``KeysView``
        (as before) made ``iter(view)`` and ``for k in view`` raise TypeError.
        """
        return iter(self.storage.columns.keys())

    def __len__(self):
        """Return the number of columns in the underlying storage."""
        return len(self.storage.columns)