Skip to content

Commit 9e00a00

Browse files
committed
feat: push meta data along with docarray
1 parent 42bf943 commit 9e00a00

File tree

3 files changed

+79
-10
lines changed

3 files changed

+79
-10
lines changed

docarray/array/mixins/io/pushpull.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,82 @@
11
import os
22
import warnings
33
from pathlib import Path
4-
from typing import Dict, Type, TYPE_CHECKING, Optional
4+
from typing import Dict, Type, TYPE_CHECKING, List, Optional
5+
6+
import hubble
7+
from hubble import Client
8+
from hubble.client.endpoints import EndpointsV2
59

610
from docarray.helper import get_request_header, __cache_path__
711

812
if TYPE_CHECKING:
913
from docarray.typing import T
1014

1115

16+
def _get_length_from_summary(summary: List[Dict]) -> Optional[int]:
17+
"""Get the length from summary."""
18+
for item in summary:
19+
if 'Length' == item['name']:
20+
return item['value']
21+
22+
1223
class PushPullMixin:
1324
"""Transmitting :class:`DocumentArray` via Jina Cloud Service"""
1425

1526
_max_bytes = 4 * 1024 * 1024 * 1024
1627

28+
@classmethod
@hubble.login_required
def cloud_list(cls, show_table: bool = False) -> List[str]:
    """List all available arrays in the cloud.

    :param show_table: if true, also print a table with name, length,
        visibility and creation/update timestamps of each array.
    :returns: List of available DocumentArray names.
    """
    artifacts = Client(jsonify=True).list_artifacts(
        filter={'type': 'documentArray'}, sort={'createdAt': 1}
    )['data']

    # The request already filters by type; the check is repeated on the
    # client side so that only DocumentArray artifacts are ever reported.
    arrays = [da for da in artifacts if da['type'] == 'documentArray']
    result = [da['name'] for da in arrays]

    if show_table:
        # Import and build the table only when it is actually displayed, so
        # the plain listing path does no rendering work.
        from rich import box
        from rich import print as rich_print
        from rich.table import Table

        table = Table(
            title='Your DocumentArray on the cloud', box=box.SIMPLE, highlight=True
        )
        table.add_column('Name')
        table.add_column('Length')
        table.add_column('Visibility')
        table.add_column('Created at', justify='center')
        table.add_column('Updated at', justify='center')

        for da in arrays:
            # ``metaData`` or its ``summary`` may be missing or null for
            # artifacts pushed without metadata; fall back to an empty summary.
            summary = (da.get('metaData') or {}).get('summary') or []
            table.add_row(
                da['name'],
                str(_get_length_from_summary(summary)),
                da['visibility'],
                da['createdAt'],
                da['updatedAt'],
            )

        rich_print(table)

    return result
69+
70+
@classmethod
@hubble.login_required
def cloud_delete(cls, name: str) -> None:
    """Remove a DocumentArray from the cloud.

    :param name: the name of the DocumentArray to delete.
    """
    # The client's response is intentionally discarded; this method
    # always returns ``None``.
    client = Client(jsonify=True)
    client.delete_artifact(name)
78+
79+
@hubble.login_required
1780
def push(
1881
self,
1982
name: str,
@@ -51,7 +114,6 @@ def push(
51114
)
52115

53116
headers = {'Content-Type': ctype, **get_request_header()}
54-
import hubble
55117

56118
auth_token = hubble.get_token()
57119
if auth_token:
@@ -98,8 +160,6 @@ def _get_chunk(_batch):
98160
yield _tail
99161

100162
with pbar:
101-
from hubble import Client
102-
from hubble.client.endpoints import EndpointsV2
103163

104164
response = requests.post(
105165
Client()._base_url + EndpointsV2.upload_artifact,
@@ -113,6 +173,7 @@ def _get_chunk(_batch):
113173
response.raise_for_status()
114174

115175
@classmethod
176+
@hubble.login_required
116177
def pull(
117178
cls: Type['T'],
118179
name: str,
@@ -133,16 +194,11 @@ def pull(
133194

134195
headers = {}
135196

136-
import hubble
137-
138197
auth_token = hubble.get_token()
139198

140199
if auth_token:
141200
headers['Authorization'] = f'token {auth_token}'
142201

143-
from hubble import Client
144-
from hubble.client.endpoints import EndpointsV2
145-
146202
url = Client()._base_url + EndpointsV2.download_artifact + f'?name={name}'
147203
response = requests.get(url, headers=headers)
148204

@@ -183,3 +239,6 @@ def pull(
183239
fp.write(_source.content)
184240

185241
return r
242+
243+
cloud_push = push
244+
cloud_pull = pull

tests/unit/array/mixins/test_io.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,16 @@ def test_push_pull_io(da_cls, config, show_progress, start_storage):
237237
assert len(da1) == len(da2) == 10
238238
assert da1.texts == da2.texts == random_texts
239239

240+
all_names = DocumentArray.cloud_list()
241+
242+
assert name in all_names
243+
244+
DocumentArray.cloud_delete(name)
245+
246+
all_names = DocumentArray.cloud_list()
247+
248+
assert name not in all_names
249+
240250

241251
@pytest.mark.parametrize(
242252
'protocol', ['protobuf', 'pickle', 'protobuf-array', 'pickle-array']

tests/unit/array/test_from_to_bytes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_from_to_safe_list(target, protocol, to_fn):
110110

111111
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
112112
@pytest.mark.parametrize('show_progress', [True, False])
113-
def test_push_pull_show_progress(show_progress, protocol):
113+
def test_to_bytes_show_progress(show_progress, protocol):
114114
da = DocumentArray.empty(1000)
115115
r = da.to_bytes(_show_progress=show_progress, protocol=protocol)
116116
da_r = DocumentArray.from_bytes(r, _show_progress=show_progress, protocol=protocol)

0 commit comments

Comments
 (0)