-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy path__init__.py
More file actions
114 lines (101 loc) · 3.16 KB
/
__init__.py
File metadata and controls
114 lines (101 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing_extensions import TYPE_CHECKING
from docarray.typing.bytes import AudioBytes, ImageBytes, VideoBytes
from docarray.typing.id import ID
from docarray.typing.tensor import ImageNdArray, ImageTensor
from docarray.typing.tensor.audio import AudioNdArray, AudioTensor
from docarray.typing.tensor.embedding.embedding import AnyEmbedding, NdArrayEmbedding
from docarray.typing.tensor.ndarray import NdArray
from docarray.typing.tensor.tensor import AnyTensor
from docarray.typing.tensor.video import VideoNdArray, VideoTensor
from docarray.typing.url import (
AnyUrl,
AudioUrl,
ImageUrl,
Mesh3DUrl,
PointCloud3DUrl,
TextUrl,
VideoUrl,
)
from docarray.utils._internal.misc import (
_get_path_from_docarray_root_level,
import_library,
)
if TYPE_CHECKING:
from docarray.typing.tensor import TensorFlowTensor # noqa: F401
from docarray.typing.tensor import ( # noqa: F401
JaxArray,
JaxArrayEmbedding,
TorchEmbedding,
TorchTensor,
)
from docarray.typing.tensor.audio import AudioJaxArray # noqa: F401
from docarray.typing.tensor.audio import AudioTensorFlowTensor # noqa: F401
from docarray.typing.tensor.audio import AudioTorchTensor # noqa: F401
from docarray.typing.tensor.embedding import TensorFlowEmbedding # noqa: F401
from docarray.typing.tensor.image import ImageJaxArray # noqa: F401
from docarray.typing.tensor.image import ImageTensorFlowTensor # noqa: F401
from docarray.typing.tensor.image import ImageTorchTensor # noqa: F401
from docarray.typing.tensor.video import VideoJaxArray # noqa: F401
from docarray.typing.tensor.video import VideoTensorFlowTensor # noqa: F401
from docarray.typing.tensor.video import VideoTorchTensor # noqa: F401
# Public names re-exported by this package. Everything listed here is imported
# eagerly at the top of this module. Names of torch / tensorflow / jax backed
# types are not listed up front: ``__getattr__`` appends each of them to this
# list the first time it is successfully resolved, so keep this a ``list``
# (it is mutated in place with ``append``).
__all__ = [
    'NdArray',
    'NdArrayEmbedding',
    'AudioNdArray',
    'VideoNdArray',
    'AnyEmbedding',
    'ImageUrl',
    'AudioUrl',
    'TextUrl',
    'Mesh3DUrl',
    'PointCloud3DUrl',
    'VideoUrl',
    'AnyUrl',
    'ID',
    'AnyTensor',
    'ImageTensor',
    'AudioTensor',
    'VideoTensor',
    'ImageNdArray',
    'ImageBytes',
    'VideoBytes',
    'AudioBytes',
]
# Lazily exported tensor types, grouped by the optional backend that
# ``__getattr__`` checks with ``import_library`` before serving the name.
# Each list holds the generic/embedding types first, then the per-modality ones.

# Served only when ``torch`` is importable.
_torch_tensors = [
    'TorchTensor', 'TorchEmbedding',
    'ImageTorchTensor', 'AudioTorchTensor', 'VideoTorchTensor',
]

# Served only when ``tensorflow`` is importable.
_tf_tensors = [
    'TensorFlowTensor', 'TensorFlowEmbedding',
    'ImageTensorFlowTensor', 'AudioTensorFlowTensor', 'VideoTensorFlowTensor',
]

# Served only when ``jax`` is importable (note: video/audio/image order).
_jax_tensors = [
    'JaxArray', 'JaxArrayEmbedding',
    'VideoJaxArray', 'AudioJaxArray', 'ImageJaxArray',
]
# Snapshot taken at import time: the eager exports followed by the torch-backed
# names. TensorFlow / JAX names are deliberately left out — presumably because
# those backends are not installed in the default test setup; confirm before
# extending. Later appends to ``__all__`` do not affect this list.
__all_test__ = [*__all__, *_torch_tensors]
def __getattr__(name: str):
    """Lazily resolve backend-specific tensor classes (PEP 562 module hook).

    Python calls this only for names that are not already in the module
    namespace. Names listed in ``_torch_tensors``, ``_tf_tensors`` or
    ``_jax_tensors`` are looked up on ``docarray.typing.tensor`` after
    ``import_library`` has been asked to load the matching backend (torch,
    tensorflow or jax). On first successful access the name is also appended
    to ``__all__``.

    :param name: the attribute being looked up on this module.
    :return: the class ``docarray.typing.tensor.<name>``.
    :raises AttributeError: for dunder names (``__wrapped__``, ``__test__``,
        ...). These are protocol probes made through ``hasattr`` or
        ``getattr(..., default)`` (e.g. by ``inspect.unwrap``), which only
        tolerate ``AttributeError``, as PEP 562 requires of this hook.
    :raises ImportError: for any other unknown name (kept for backward
        compatibility). ``import_library(..., raise_error=True)`` presumably
        also raises when the requested backend is not installed.
    """
    # Dunder lookups are introspection probes, never lazily exported types;
    # answering them with AttributeError keeps hasattr()/getattr(default)
    # working instead of leaking an ImportError out of them.
    if name.startswith('__') and name.endswith('__'):
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

    if name in _torch_tensors:
        import_library('torch', raise_error=True)
    elif name in _tf_tensors:
        import_library('tensorflow', raise_error=True)
    elif name in _jax_tensors:
        import_library('jax', raise_error=True)
    else:
        raise ImportError(
            f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\''
        )

    # Deferred import: the backend check above must run before the optional
    # tensor class is fetched.
    import docarray.typing.tensor

    tensor_cls = getattr(docarray.typing.tensor, name)

    # Advertise the name as public once its backend has been confirmed.
    if name not in __all__:
        __all__.append(name)

    return tensor_cls