-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathtest_tensor.py
More file actions
45 lines (32 loc) · 1.22 KB
/
Copy pathtest_tensor.py
File metadata and controls
45 lines (32 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import pytest
import torch
from docarray import BaseDoc
from docarray.typing import AnyTensor, NdArray, TorchTensor
from docarray.utils._internal.misc import is_tf_available
# Probe for TensorFlow once; the TensorFlow-specific test is gated on this flag.
tf_available = is_tf_available()

if not tf_available:
    # Placeholder so the module still imports cleanly without TensorFlow.
    TensorFlowTensor = None
else:
    import tensorflow as tf
    import tensorflow._api.v2.experimental.numpy as tnp  # type: ignore

    from docarray.typing import TensorFlowTensor
def test_set_tensor():
    """Check that an ``AnyTensor`` field stores NumPy and PyTorch inputs as docarray tensors.

    A NumPy array must come back as an ``NdArray`` (which is also an
    ``np.ndarray``). A torch tensor must come back as a ``TorchTensor``
    (which is also a ``torch.Tensor``). In both cases the shape and the
    all-zero values must be preserved.
    """

    class MyDocument(BaseDoc):
        tensor: AnyTensor

    shape = (3, 224, 224)

    # NumPy input
    d = MyDocument(tensor=np.zeros(shape))
    assert isinstance(d.tensor, NdArray)
    assert isinstance(d.tensor, np.ndarray)
    # Pin the shape explicitly: the element-wise comparison below broadcasts,
    # so on its own it would also accept a tensor of a compatible wrong shape.
    assert d.tensor.shape == shape
    assert (d.tensor == np.zeros(shape)).all()

    # PyTorch input
    d = MyDocument(tensor=torch.zeros(shape))
    assert isinstance(d.tensor, TorchTensor)
    assert isinstance(d.tensor, torch.Tensor)
    assert d.tensor.shape == shape  # torch.Size compares equal to a plain tuple
    assert (d.tensor == torch.zeros(shape)).all()
@pytest.mark.tensorflow
# Without this skip the test would fail with NameError on `tf` / `tnp`
# (or compare against TensorFlowTensor=None) when TensorFlow is absent.
@pytest.mark.skipif(not tf_available, reason="TensorFlow is not available")
def test_set_tensor_tensorflow():
    """Check that an ``AnyTensor`` field wraps a TensorFlow tensor in ``TensorFlowTensor``.

    The wrapped ``tf.Tensor`` is exposed via ``d.tensor.tensor``. Its shape
    and all-zero values must be preserved.
    """

    class MyDocument(BaseDoc):
        tensor: AnyTensor

    shape = (3, 224, 224)

    d = MyDocument(tensor=tf.zeros(shape))
    assert isinstance(d.tensor, TensorFlowTensor)
    assert isinstance(d.tensor.tensor, tf.Tensor)
    # tnp.allclose broadcasts, so check the shape separately.
    assert tuple(d.tensor.tensor.shape) == shape
    assert tnp.allclose(d.tensor.tensor, tf.zeros(shape))