-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathtest_sample.py
More file actions
124 lines (115 loc) · 4.5 KB
/
Copy pathtest_sample.py
File metadata and controls
124 lines (115 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pytest
from docarray import DocumentArray
from docarray.array.opensearch import DocumentArrayOpenSearch
from docarray.array.qdrant import DocumentArrayQdrant
from docarray.array.sqlite import DocumentArraySqlite
from docarray.array.annlite import DocumentArrayAnnlite, AnnliteConfig
from docarray.array.storage.opensearch import OpenSearchConfig
from docarray.array.storage.qdrant import QdrantConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.redis import DocumentArrayRedis, RedisConfig
from docarray.array.milvus import DocumentArrayMilvus, MilvusConfig
@pytest.mark.parametrize(
    'da_cls,config',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, QdrantConfig(n_dim=128)),
        (DocumentArrayElastic, ElasticConfig(n_dim=128)),
        (DocumentArrayOpenSearch, OpenSearchConfig(n_dim=128)),
        (DocumentArrayRedis, RedisConfig(n_dim=128)),
        (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
    ],
)
def test_sample(da_cls, config, start_storage):
    """Check that ``sample(k)`` returns ``k`` documents and rejects ``k > len(da)``.

    Runs once per storage backend listed above. ``start_storage`` is a fixture
    defined outside this file; it is requested only for its setup side effect
    (presumably bringing up the backend services — confirm in conftest).
    """
    # Backends without a config (in-memory, SQLite) are constructed without one.
    if config is not None:
        da = da_cls.empty(100, config=config)
    else:
        da = da_cls.empty(100)

    for k in (1, 5):
        sampled = da.sample(k)
        assert len(sampled) == k
        assert isinstance(sampled, DocumentArray)

    # Cannot sample more documents than the array holds. The bound is derived
    # from len(da) so it stays correct if the array size above changes.
    with pytest.raises(ValueError):
        da.sample(len(da) + 1)
@pytest.mark.parametrize(
    'da_cls,config',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, QdrantConfig(n_dim=128)),
        (DocumentArrayElastic, ElasticConfig(n_dim=128)),
        (DocumentArrayOpenSearch, OpenSearchConfig(n_dim=128)),
        (DocumentArrayRedis, RedisConfig(n_dim=128)),
        (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
    ],
)
def test_sample_with_seed(da_cls, config, start_storage):
    """Seeded sampling is reproducible, and a different seed gives a different sample.

    ``start_storage`` is a fixture defined outside this file, requested only
    for its setup side effect.
    """
    # Only pass ``config`` through when the backend was given one.
    construct_kwargs = {'config': config} if config else {}
    da = da_cls.empty(100, **construct_kwargs)

    # Seeds are consumed in the order 1, 1, 2.
    first, repeat, other = (da.sample(5, seed=seed) for seed in (1, 1, 2))

    assert len(first) == len(repeat) == len(other) == 5
    assert first == repeat
    assert first != other
@pytest.mark.parametrize(
    'da_cls,config',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, QdrantConfig(n_dim=128)),
        (DocumentArrayElastic, ElasticConfig(n_dim=128)),
        (DocumentArrayOpenSearch, OpenSearchConfig(n_dim=128)),
        (DocumentArrayRedis, RedisConfig(n_dim=128)),
        (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
    ],
)
def test_shuffle(da_cls, config, start_storage):
    """Check that ``shuffle()`` returns a reordered copy and leaves the source untouched.

    ``start_storage`` is a fixture defined outside this file, requested only
    for its setup side effect.
    """
    # Backends without a config (in-memory, SQLite) are constructed without one.
    if config is not None:
        da = da_cls.empty(100, config=config)
    else:
        da = da_cls.empty(100)

    # Record the order *before* shuffling so the variable holds what its name says.
    ids_before_shuffle = [d.id for d in da]

    shuffled = da.shuffle()
    assert len(shuffled) == len(da)
    assert isinstance(shuffled, DocumentArray)

    ids_after_shuffle = [d.id for d in shuffled]
    # Assuming a uniform shuffle, an identity permutation of 100 items has
    # probability 1/100!, so this is effectively deterministic.
    assert ids_after_shuffle != ids_before_shuffle
    # Same documents, only reordered.
    assert sorted(ids_after_shuffle) == sorted(ids_before_shuffle)
    # shuffle() returns a new array; the source keeps its original order.
    assert [d.id for d in da] == ids_before_shuffle
@pytest.mark.parametrize(
    'da_cls,config',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, QdrantConfig(n_dim=128)),
        (DocumentArrayElastic, ElasticConfig(n_dim=128)),
        (DocumentArrayOpenSearch, OpenSearchConfig(n_dim=128)),
        (DocumentArrayRedis, RedisConfig(n_dim=128)),
        (DocumentArrayMilvus, MilvusConfig(n_dim=128)),
    ],
)
def test_shuffle_with_seed(da_cls, config, start_storage):
    """Seeded shuffling is reproducible, and a different seed gives a different order.

    ``start_storage`` is a fixture defined outside this file, requested only
    for its setup side effect.
    """
    da = da_cls.empty(100, config=config) if config else da_cls.empty(100)

    # Shuffle three times, consuming seeds in the order 1, 1, 2.
    same_seed_a, same_seed_b, other_seed = [
        da.shuffle(seed=seed) for seed in (1, 1, 2)
    ]

    assert len(same_seed_a) == len(same_seed_b) == len(other_seed) == len(da)
    assert same_seed_a == same_seed_b
    assert same_seed_a != other_seed