forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpushpull.py
More file actions
155 lines (127 loc) · 5.16 KB
/
pushpull.py
File metadata and controls
155 lines (127 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import io
from contextlib import nullcontext
from typing import Type, TYPE_CHECKING, Optional
from ....helper import get_request_header
if TYPE_CHECKING:
from ....types import T
class PushPullMixin:
    """Transmitting :class:`DocumentArray` via Jina Cloud Service"""

    # Base RPC endpoint; the method name ('push' / 'pull') is appended to it.
    _service_url = 'https://apihubble.jina.ai/v2/rpc/da.'
    # Upper bound (in bytes) on the serialized payload accepted by :meth:`push`.
    _max_bytes = 4 * 1024 * 1024 * 1024

    def push(self, token: str, show_progress: bool = False) -> None:
        """Push this DocumentArray object to Jina Cloud which can be later retrieved via :meth:`.pull`

        .. note::
            - Push with the same ``token`` will override the existing content.
            - Kinda like a public clipboard where everyone can override anyone's content.
              So to make your content survive longer, you may want to use longer & more complicated token.
            - The lifetime of the content is not promised atm, could be a day, could be a week. Do not use it for
              persistence. Only use this for temporary transmission/storage/clipboard.

        :param token: a key that later can be used for retrieve this :class:`DocumentArray`.
        :param show_progress: if to show a progress bar on pushing
        :raises ValueError: if the serialized DocumentArray exceeds ``_max_bytes``.
        :raises requests.HTTPError: if the service answers with an HTTP error status.
        """
        import requests

        dict_data = self._get_dict_data(token, show_progress)

        progress = _get_progressbar(show_progress)
        task_id = progress.add_task('upload', start=False) if show_progress else None

        class BufferReader(io.BytesIO):
            """In-memory request body that reports every ``read`` to a progress bar."""

            def __init__(self, buf=b'', p_bar=None, task_id=None):
                super().__init__(buf)
                self._len = len(buf)
                self._p_bar = p_bar
                self._task_id = task_id
                if self._p_bar is not None:
                    # The total is only known once the multipart body is encoded.
                    self._p_bar.update(self._task_id, total=self._len)
                    self._p_bar.start_task(self._task_id)

            def __len__(self):
                # Exposes the body size; presumably consumed by the HTTP client
                # to size the request -- TODO confirm.
                return self._len

            def read(self, n=-1):
                chunk = io.BytesIO.read(self, n)
                if self._p_bar is not None:
                    self._p_bar.update(self._task_id, advance=len(chunk))
                return chunk

        (data, ctype) = requests.packages.urllib3.filepost.encode_multipart_formdata(
            dict_data
        )

        headers = {'Content-Type': ctype, **get_request_header()}

        with progress as p_bar:
            # ``p_bar`` is None when no progress bar was requested.
            body = BufferReader(data, p_bar, task_id)
            response = requests.post(
                self._service_url + 'push', data=body, headers=headers
            )

        # Surface a failed upload instead of silently dropping the response,
        # matching the status check done on the download request in ``pull``.
        response.raise_for_status()

    @classmethod
    def pull(
        cls: Type['T'],
        token: str,
        show_progress: bool = False,
    ) -> 'T':
        """Pulling a :class:`DocumentArray` from Jina Cloud Service to local.

        :param token: the upload token set during :meth:`.push`
        :param show_progress: if to show a progress bar on pulling
        :return: a :class:`DocumentArray` object
        :raises requests.HTTPError: if either the lookup or the download request
            answers with an HTTP error status.
        """
        import requests

        # Let requests build the query string so that tokens containing
        # reserved characters ('&', '#', spaces, ...) are escaped correctly.
        response = requests.get(f'{cls._service_url}pull', params={'token': token})
        response.raise_for_status()
        url = response.json()['data']['download']

        progress = _get_progressbar(show_progress)

        with requests.get(
            url,
            stream=True,
            headers=get_request_header(),
        ) as r, progress:
            r.raise_for_status()

            task_id = None
            if show_progress:
                task_id = progress.add_task('download', start=False)
                # The header may be absent; only set the total when the size
                # is actually announced instead of failing with a KeyError.
                content_length = r.headers.get('Content-Length')
                if content_length is not None:
                    progress.update(task_id, total=int(content_length))

            with io.BytesIO() as f:
                chunk_size = 8192
                if show_progress:
                    progress.start_task(task_id)
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    if show_progress:
                        progress.update(task_id, advance=len(chunk))

                if show_progress:
                    # Stop the download bar before deserialization, which is
                    # itself asked to show progress.
                    progress.stop()

                return cls.from_bytes(
                    f.getvalue(),
                    protocol='protobuf',
                    compress='gzip',
                    _show_progress=show_progress,
                )

    def _get_dict_data(self, token, show_progress):
        """Serialize this object and build the multipart form fields for :meth:`push`.

        :param token: the key under which the content is stored
        :param show_progress: forwarded to ``to_bytes`` as ``_show_progress``
        :return: a dict with the ``file`` field (name, gzip-compressed protobuf bytes)
            and the ``token`` field
        :raises ValueError: if the serialization is larger than ``_max_bytes``
        """
        _serialized = self.to_bytes(
            protocol='protobuf', compress='gzip', _show_progress=show_progress
        )
        size = len(_serialized)
        if size > self._max_bytes:
            raise ValueError(
                f'DocumentArray is too big. '
                f'Size of the serialization {size} is larger than {self._max_bytes}.'
            )

        return {
            'file': (
                'DocumentArray',
                _serialized,
            ),
            'token': token,
        }
def _get_progressbar(show_progress):
if show_progress:
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TimeRemainingColumn,
TransferSpeedColumn,
)
return Progress(
BarColumn(bar_width=None),
"[progress.percentage]{task.percentage:>3.1f}%",
"•",
DownloadColumn(),
"•",
TransferSpeedColumn(),
"•",
TimeRemainingColumn(),
transient=True,
)
else:
return nullcontext()