forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathndarray.py
More file actions
147 lines (110 loc) · 4.45 KB
/
ndarray.py
File metadata and controls
147 lines (110 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from typing import TYPE_CHECKING, Tuple, Sequence, Optional
import numpy as np
if TYPE_CHECKING:
from ..types import ArrayType
from .. import Document
def unravel(docs: Sequence['Document'], field: str) -> Optional['ArrayType']:
    """Collect ``doc.field`` from every document and stack the values into one object.

    The framework of the result is decided by the value found on the first
    document, as reported by :func:`get_array_type`.

    :param docs: the documents to read from
    :param field: the name of the attribute to read on each document
    :return: ``None`` when ``docs`` is empty or when every document holds
        ``None``; a plain list of the raw values when the first document holds
        ``None`` but at least one other does not; otherwise the stacked values
        in the framework of the first value. ``None`` is also returned for a
        framework that has no stacking rule here.
    """
    if not docs:
        # nothing to stack; also avoids an IndexError when looking at docs[0]
        return None

    # read every value exactly once, the first one decides how to stack
    all_fields = [getattr(d, field) for d in docs]
    _first = all_fields[0]

    if _first is None:
        # failed to unravel, return as a list (or None if there is no value at all)
        if any(_v is not None for _v in all_fields):
            return all_fields
        return None

    framework, _ = get_array_type(_first)
    cls_type = type(_first)

    if framework == 'python':
        # rebuild the same container type (list or tuple) as the first value
        return cls_type(all_fields)
    elif framework == 'numpy':
        return np.stack(all_fields)
    elif framework == 'tensorflow':
        # framework imports are kept local so they are only needed when used
        import tensorflow as tf

        return tf.stack(all_fields)
    elif framework == 'torch':
        import torch

        return torch.stack(all_fields)
    elif framework == 'paddle':
        import paddle

        return paddle.stack(all_fields)
    elif framework == 'scipy':
        import scipy.sparse

        # vstack may change the sparse format, convert back to the input class
        return cls_type(scipy.sparse.vstack(all_fields))

    # no stacking rule for this framework
    return None
def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None:
    """Write the rows of :attr:`value` onto ``doc.field``, one row per document.

    Documents and rows are paired positionally with :func:`zip`, so the
    assignment stops as soon as either side runs out.

    :param value: the value to be split; a list/tuple is consumed item by item,
        anything else is indexed along its first axis
    :param docs: the docs to set
    :param field: the field of the doc to set
    """
    # Only objects exposing ``getformat`` (scipy sparse) can need the
    # row-by-row fallback; ``row_format`` stays None for everything else.
    row_format = None
    if hasattr(value, 'getformat'):
        detected = value.getformat()
        # BSR and COO do not implement ``[j, ...]`` in scipy
        if detected in ('bsr', 'coo'):
            row_format = detected

    if row_format is not None:
        # ``getrow`` implicitly hands back the row in CSR format, so every row
        # is converted back to the input format. Not very efficient, but it
        # is the best that can be done here.
        converter_name = f'to{row_format}'
        n_rows = value.shape[0]
        for doc, idx in zip(docs, range(n_rows)):
            sparse_row = value.getrow(idx)
            setattr(doc, field, getattr(sparse_row, converter_name)())
        return

    if isinstance(value, (list, tuple)):
        for doc, item in zip(docs, value):
            setattr(doc, field, item)
        return

    n_rows = value.shape[0]
    for doc, idx in zip(docs, range(n_rows)):
        setattr(doc, field, value[idx, ...])
def get_array_type(array: 'ArrayType') -> Tuple[str, bool]:
    """Identify the framework of an array-like object without importing that framework.

    The decision is based on the dotted module path and the name of the
    object's class; list and tuple instances are recognised via ``isinstance``.

    :param array: any array, scipy, numpy, tf, torch, etc.
    :return: a ``(framework, is_sparse)`` tuple: the framework name and whether
        the array is sparse
    :raises TypeError: if none of the known frameworks matches
    """
    array_cls = array.__class__
    module_tags = array_cls.__module__.split('.')
    class_name = array_cls.__name__

    if isinstance(array, (list, tuple)):
        return 'python', False

    if 'numpy' in module_tags:
        return 'numpy', False

    # for the two jina containers, sparse or not is irrelevant
    if 'jina' in module_tags and class_name == 'NdArray':
        return 'jina', False
    if 'docarray_pb2' in module_tags and class_name == 'NdArrayProto':
        return 'jina_proto', False

    if 'tensorflow' in module_tags:
        if class_name == 'SparseTensor':
            return 'tensorflow', True
        if class_name in ('Tensor', 'EagerTensor'):
            return 'tensorflow', False

    if class_name == 'Tensor':
        if 'torch' in module_tags:
            return 'torch', array.is_sparse
        if 'paddle' in module_tags:
            # Paddle does not support sparse tensor on 11/8/2021
            # https://github.com/PaddlePaddle/Paddle/issues/36697
            return 'paddle', False

    if 'scipy' in module_tags and 'sparse' in module_tags:
        return 'scipy', True

    raise TypeError(f'can not determine the array type: {module_tags}.{class_name}')
def to_numpy_array(value) -> 'np.ndarray':
    """Convert :attr:`value` towards a numpy representation, whatever framework it comes from.

    A sparse input is densified first; afterwards the object's own ``numpy()``
    method is used when it has one. An object that offers no ``numpy()``
    method after that step is returned as it is.

    :param value: the array to convert, of any type known to :func:`get_array_type`
    :return: the converted value
    """
    framework, is_sparse = get_array_type(value)
    result = value

    if is_sparse:
        # try the densify methods in order of preference, first match wins
        for method_name in ('todense', 'to_dense'):
            if hasattr(result, method_name):
                result = getattr(result, method_name)()
                break
        else:
            # no densify method on the object itself
            if framework == 'tensorflow':
                import tensorflow as tf

                if isinstance(result, tf.SparseTensor):
                    result = tf.sparse.to_dense(result)

    if hasattr(result, 'numpy'):
        result = result.numpy()

    return result