-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathmisc.py
More file actions
130 lines (101 loc) · 3.15 KB
/
misc.py
File metadata and controls
130 lines (101 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import importlib
import os
import re
import types
from typing import Any, Optional, Literal
import numpy as np
# Optional tensor backends. Each import is attempted once when this module is
# loaded, and the matching *_imported flag records whether it succeeded.
try:
    import torch  # noqa: F401
    torch_imported = True
except ImportError:
    torch_imported = False

try:
    import tensorflow as tf  # type: ignore # noqa: F401
    tf_imported = True
except (ImportError, TypeError):
    # TypeError is tolerated in addition to ImportError here; presumably some
    # tensorflow/protobuf combinations fail that way on import (TODO confirm).
    tf_imported = False

try:
    import jax.numpy as jnp  # type: ignore # noqa: F401
    jnp_imported = True
except (ImportError, TypeError):
    jnp_imported = False
# Maps an importable module name to the argument suggested after `pip install`
# when that module is missing; `import_library` reads it to build its error
# message. Several modules can share one docarray extra.
INSTALL_INSTRUCTIONS = {
    # docarray[proto]
    'google.protobuf': '"docarray[proto]"',
    'lz4': '"docarray[proto]"',
    # optional data-type extras
    'pandas': '"docarray[pandas]"',
    'PIL': '"docarray[image]"',
    'pydub': '"docarray[audio]"',
    'av': '"docarray[video]"',
    'trimesh': '"docarray[mesh]"',
    # backend extras
    'hnswlib': '"docarray[hnswlib]"',
    'elasticsearch': '"docarray[elasticsearch]"',
    'elastic_transport': '"docarray[elasticsearch]"',
    'weaviate': '"docarray[weaviate]"',
    'qdrant_client': '"docarray[qdrant]"',
    'fastapi': '"docarray[web]"',
    'torch': '"docarray[torch]"',
    # tensorflow has no docarray extra; the hint pins protobuf next to it.
    'tensorflow': 'protobuf==3.19.0 tensorflow',
    # docarray[aws]
    'smart_open': '"docarray[aws]"',
    'boto3': '"docarray[aws]"',
    'botocore': '"docarray[aws]"',
    'redis': '"docarray[redis]"',
    'pymilvus': '"docarray[milvus]"',
}
# String names of the supported serialization protocols. The '-array' variants
# presumably encode a whole collection rather than a single item (TODO confirm).
ProtocolType = Literal[
    'protobuf',
    'pickle',
    'json',
    'json-array',
    'protobuf-array',
    'pickle-array',
]
def import_library(
    package: str, raise_error: bool = True
) -> Optional[types.ModuleType]:
    """
    Import a module by its dotted name, optionally failing with an install hint.

    :param package: importable module name, e.g. ``'hnswlib'`` or
        ``'google.protobuf'``.
    :param raise_error: if True, raise ``ImportError`` when the module cannot be
        imported; if False, return None instead.
    :return: the imported module, or None when it is unavailable and
        ``raise_error`` is False.
    :raises ImportError: if the import fails and ``raise_error`` is True. The
        message suggests a ``pip install`` argument taken from
        ``INSTALL_INSTRUCTIONS`` for the exact name, else for its top-level
        package, else the top-level package name itself. The original import
        error is chained as the cause.
    """
    try:
        return importlib.import_module(package)
    except ImportError as err:  # also covers its subclass ModuleNotFoundError
        if not raise_error:
            return None
        # Never let a missing hint turn into a KeyError: try the exact name,
        # then the top-level package (e.g. 'PIL' for 'PIL.Image'), and finally
        # fall back to suggesting the top-level name directly.
        top_level = package.split('.')[0]
        install_target = INSTALL_INSTRUCTIONS.get(
            package, INSTALL_INSTRUCTIONS.get(top_level, top_level)
        )
        raise ImportError(
            f'The following required library is not installed: {package} \n'
            f'To install all necessary libraries, run: `pip install {install_target}`.'
        ) from err
def _get_path_from_docarray_root_level(file_path: str) -> str:
path = os.path.dirname(file_path)
rel_path = re.sub('(?s:.*)docarray', 'docarray', path).replace('/', '.')
return rel_path
def is_torch_available():
    """
    Tell whether ``torch`` could be imported when this module was loaded.

    :return: the module-level ``torch_imported`` flag.
    """
    return torch_imported
def is_tf_available():
    """
    Tell whether ``tensorflow`` could be imported when this module was loaded.

    :return: the module-level ``tf_imported`` flag.
    """
    return tf_imported
def is_jax_available():
    """
    Tell whether ``jax.numpy`` could be imported when this module was loaded.

    :return: the module-level ``jnp_imported`` flag.
    """
    return jnp_imported
def is_np_int(item: Any) -> bool:
    """
    Check whether ``item`` is a zero-dimensional value with a numpy integer dtype.

    Any object exposing ``dtype`` and ``ndim`` attributes is inspected (for
    example ``np.int64(3)`` or ``np.array(3)``). Objects without them, such as a
    plain Python ``int``, are reported as False.

    :param item: the object to inspect.
    :return: True if ``item.ndim == 0`` and numpy considers ``item.dtype`` an
        integer type, else False. Always a plain Python ``bool``.
    """
    dtype = getattr(item, 'dtype', None)
    ndim = getattr(item, 'ndim', None)
    if dtype is None or ndim is None:
        # No array metadata (e.g. a plain Python int): not a numpy integer.
        return False
    try:
        # bool() keeps numpy booleans from leaking out of the comparison, e.g.
        # when `ndim` is itself a numpy integer.
        return bool(ndim == 0 and np.issubdtype(dtype, np.integer))
    except TypeError:
        # numpy could not interpret `dtype` (presumably a non-numpy dtype object).
        return False
def is_notebook() -> bool:
    """
    Detect whether the interpreter is running inside a notebook kernel.

    IPython sessions expose a global ``get_ipython`` callable. When that name is
    missing we are in a plain Python process. Otherwise the class name of the
    active shell decides:

    - ``ZMQInteractiveShell`` (Jupyter kernel) counts as a notebook.
    - ``Shell`` (presumably Google Colab; TODO confirm) counts as a notebook.
    - ``TerminalInteractiveShell`` and any other shell do not.

    :return: True if run in a notebook kernel else False.
    """
    try:
        shell_name = get_ipython().__class__.__name__  # type: ignore
    except NameError:
        return False
    notebook_shells = ('ZMQInteractiveShell', 'Shell')
    return shell_name in notebook_shells