-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathndarray.py
More file actions
155 lines (125 loc) · 5.08 KB
/
ndarray.py
File metadata and controls
155 lines (125 loc) · 5.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import TYPE_CHECKING, Optional
import numpy as np
from docarray.math.ndarray import get_array_type, to_numpy_array
if TYPE_CHECKING: # pragma: no cover
from docarray.typing import ArrayType
from docarray.proto.docarray_pb2 import NdArrayProto
def read_ndarray(pb_msg: 'NdArrayProto') -> 'ArrayType':
    """Rebuild an array from an ``NdArrayProto`` message.

    ``pb_msg.cls_name`` selects the framework of the returned array. Sparse
    content (the ``'sparse'`` oneof) is rebuilt for ``scipy``, ``tensorflow``
    and ``torch``; dense content for ``numpy``, ``torch``, ``paddle``,
    ``tensorflow`` and ``list``.

    :param pb_msg: the message to decode.
    :return: the reconstructed array, or ``None`` when ``cls_name`` names a
        framework that is not handled for the stored content kind.
    :raises ValueError: for scipy content whose stored ``sparse_format`` is
        not a known scipy sparse format.
    """
    framework = pb_msg.cls_name

    if pb_msg.WhichOneof('content') == 'sparse':
        if framework == 'scipy':
            idx, val, shape = _get_raw_sparse_array(pb_msg)
            from scipy.sparse import coo_matrix

            # indices were written as an (nnz, 2) array of (row, col) pairs
            x = coo_matrix((val, idx.T), shape=shape)
            # ``asformat`` handles every scipy sparse format (bsr, coo, csc,
            # csr, dia, dok, lil), so matrices flushed in any format reported
            # by ``getformat()`` round-trip instead of coming back as None.
            return x.asformat(pb_msg.parameters['sparse_format'])
        elif framework == 'tensorflow':
            idx, val, shape = _get_raw_sparse_array(pb_msg)
            from tensorflow import SparseTensor

            return SparseTensor(idx, val, shape)
        elif framework == 'torch':
            idx, val, shape = _get_raw_sparse_array(pb_msg)
            from torch import sparse_coo_tensor

            return sparse_coo_tensor(idx, val, shape)
    elif framework in {'numpy', 'torch', 'paddle', 'tensorflow', 'list'}:
        x = _get_dense_array(pb_msg.dense)
        return _to_framework_array(x, framework)
    return None
def flush_ndarray(
    pb_msg: 'NdArrayProto', value: 'ArrayType', ndarray_type: Optional[str] = None
):
    """Serialize ``value`` into the ``NdArrayProto`` message ``pb_msg`` in place.

    Sparse scipy/tensorflow/torch arrays are written to ``pb_msg.sparse``;
    dense numpy/list/tensorflow/torch/paddle arrays to ``pb_msg.dense``.
    ``cls_name`` records the source framework so ``read_ndarray`` can rebuild
    the same array type.

    :param pb_msg: the message to write into; modified in place.
    :param value: the array to serialize.
    :param ndarray_type: if ``'list'``, ``value`` is first converted to a
        nested Python list; if ``'numpy'``, to a numpy array. Any other value
        (including ``None``) leaves ``value`` as-is.
    """
    if ndarray_type == 'list':
        value = to_numpy_array(value).tolist()
    elif ndarray_type == 'numpy':
        value = to_numpy_array(value)

    framework, is_sparse = get_array_type(value)

    if framework in ('docarray', 'docarray_proto'):
        # ``value`` is already an array message (Jina's NdArray or its proto);
        # copy it verbatim. ``CopyFrom`` clears ``pb_msg`` before merging, so
        # the ``cls_name`` default can only be applied *after* the copy --
        # assigning it beforehand would simply be wiped out.
        pb_msg.CopyFrom(value)
        if not pb_msg.cls_name:
            pb_msg.cls_name = 'numpy'
    elif is_sparse:
        if framework == 'scipy':
            # remembered so ``read_ndarray`` can restore the original format
            pb_msg.parameters['sparse_format'] = value.getformat()
            _set_scipy_sparse(pb_msg, value)
        elif framework == 'tensorflow':
            _set_tf_sparse(pb_msg, value)
        elif framework == 'torch':
            _set_torch_sparse(pb_msg, value)
    elif framework == 'numpy':
        pb_msg.cls_name = 'numpy'
        _set_dense_array(pb_msg.dense, value)
    elif framework == 'python':
        pb_msg.cls_name = 'list'
        _set_dense_array(pb_msg.dense, np.array(value))
    elif framework in ('tensorflow', 'paddle'):
        pb_msg.cls_name = framework
        _set_dense_array(pb_msg.dense, value.numpy())
    elif framework == 'torch':
        pb_msg.cls_name = 'torch'
        # detach from autograd and move to CPU before converting to numpy
        _set_dense_array(pb_msg.dense, value.detach().cpu().numpy())
    # NOTE(review): any other framework is silently ignored and leaves
    # ``pb_msg`` untouched -- confirm whether this should raise instead.
def _set_dense_array(pb_msg, value):
    """Serialize the numpy array ``value`` into a dense-array message.

    Records the dtype string (e.g. ``'<f4'``), stores the raw bytes from
    ``value.tobytes()``, and replaces any previously stored shape with
    ``value.shape`` so the reader can reinterpret the bytes.
    """
    dims = list(value.shape)
    pb_msg.dtype = value.dtype.str
    pb_msg.buffer = value.tobytes()
    pb_msg.ClearField('shape')
    pb_msg.shape.extend(dims)
def _set_scipy_sparse(pb_msg, value):
    """Store a scipy sparse matrix in ``pb_msg.sparse`` as COO triplets.

    The row/column coordinates are written as a single ``(nnz, 2)`` index
    array, the non-zero entries as the values array, and the matrix shape
    alongside them; ``cls_name`` is set to ``'scipy'``.
    """
    coo = value.tocoo(copy=True)
    sparse = pb_msg.sparse
    _set_dense_array(sparse.indices, np.column_stack((coo.row, coo.col)))
    _set_dense_array(sparse.values, coo.data)
    sparse.ClearField('shape')
    sparse.shape.extend(coo.shape)
    pb_msg.cls_name = 'scipy'
def _set_tf_sparse(pb_msg, value):
    """Store a TensorFlow sparse tensor in ``pb_msg.sparse``.

    Indices and values are converted with ``.numpy()`` and written as dense
    arrays, the tensor's shape is stored, and ``cls_name`` becomes
    ``'tensorflow'``.
    """
    sparse = pb_msg.sparse
    for target, tensor in (
        (sparse.indices, value.indices),
        (sparse.values, value.values),
    ):
        _set_dense_array(target, tensor.numpy())
    sparse.ClearField('shape')
    sparse.shape.extend(value.shape)
    pb_msg.cls_name = 'tensorflow'
def _set_torch_sparse(pb_msg, value):
    """Store a torch sparse COO tensor in ``pb_msg.sparse``.

    The tensor is detached from autograd and coalesced once. Its indices and
    values are moved to CPU before the numpy conversion, the same way the
    dense torch path is handled, so GPU tensors and tensors that require
    grad can be serialized too. ``cls_name`` is set to ``'torch'``.
    """
    coalesced = value.detach().coalesce()
    _set_dense_array(pb_msg.sparse.indices, coalesced.indices().cpu().numpy())
    _set_dense_array(pb_msg.sparse.values, coalesced.values().cpu().numpy())
    pb_msg.sparse.ClearField('shape')
    pb_msg.sparse.shape.extend(list(value.size()))
    pb_msg.cls_name = 'torch'
def _get_raw_sparse_array(pb_msg):
    """Return the ``(indices, values, shape)`` triple stored in ``pb_msg.sparse``.

    Indices and values are decoded as numpy arrays; the shape is returned as
    a plain list.
    """
    sparse = pb_msg.sparse
    return (
        _get_dense_array(sparse.indices),
        _get_dense_array(sparse.values),
        list(sparse.shape),
    )
def _get_dense_array(source):
    """Decode a dense-array message back into a numpy array.

    :param source: message carrying ``buffer`` (raw bytes), ``shape`` and
        ``dtype`` (numpy dtype string), as written by ``_set_dense_array``.
    :return: a writable numpy array with the stored dtype and shape, or
        ``None`` when neither a buffer nor a shape was stored.
    """
    if source.buffer:
        # ``np.frombuffer`` over immutable ``bytes`` yields a read-only array;
        # copying into a ``bytearray`` makes the result writable, so callers
        # (and ``torch.from_numpy``) can safely modify it.
        x = np.frombuffer(bytearray(source.buffer), dtype=source.dtype)
        return x.reshape(tuple(source.shape))
    elif len(source.shape) > 0:
        # A zero-element array serializes to an empty buffer; keep its
        # recorded dtype (e.g. integer sparse indices) instead of float64.
        return np.zeros(tuple(source.shape), dtype=source.dtype or None)
    return None
def _to_framework_array(x, framework):
    """Convert the numpy array ``x`` into the array type named by ``framework``.

    ``'numpy'`` returns ``x`` unchanged and ``'list'`` returns ``x.tolist()``;
    ``'tensorflow'``, ``'torch'`` and ``'paddle'`` convert via that library,
    which is imported only when requested. Any other name yields ``None``.
    """
    if framework == 'list':
        return x.tolist()
    if framework == 'numpy':
        return x
    if framework == 'torch':
        import torch

        return torch.from_numpy(x)
    if framework == 'tensorflow':
        import tensorflow

        return tensorflow.convert_to_tensor(x)
    if framework == 'paddle':
        import paddle

        return paddle.to_tensor(x)
    return None