forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
311 lines (261 loc) · 11.4 KB
/
plot.py
File metadata and controls
311 lines (261 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import copy
import json
import os.path
import tempfile
import threading
import warnings
from collections import Counter
from math import sqrt, ceil, floor
from typing import Optional
import numpy as np
class PlotMixin:
    """Helper functions for plotting the arrays."""

    def summary(self):
        """Print the structure and attribute summary of this DocumentArray object.

        Two tables are printed with ``rich``: a document-level summary (length,
        homogeneity, nesting) and, when any attribute is set, a per-attribute
        summary (data types, number of distinct values, presence of ``None``).

        .. warning::
            Calling :meth:`.summary` on large DocumentArray can be slow.
        """
        from rich import box
        from rich.console import Console
        from rich.table import Table

        # presumably one tuple of non-empty field names per Document (each item is
        # iterated and hashed below) -- confirm against ``get_attributes``
        all_attrs = self.get_attributes('non_empty_fields')
        attr_counter = Counter(all_attrs)

        table = Table(box=box.SIMPLE, title='Documents Summary')
        table.show_header = False
        table.add_row('Length', str(len(self)))
        # homogeneous == every Document has exactly the same set of non-empty fields
        is_homo = len(attr_counter) == 1
        table.add_row('Homogenous Documents', str(is_homo))

        all_attrs_names = {v for k in all_attrs for v in k}
        _nested_in = []
        if 'chunks' in all_attrs_names:
            _nested_in.append('chunks')
        if 'matches' in all_attrs_names:
            _nested_in.append('matches')
        if _nested_in:
            table.add_row('Has nested Documents in', str(tuple(_nested_in)))

        if is_homo:
            table.add_row('Common Attributes', str(next(iter(attr_counter))))
        else:
            for _a, _n in attr_counter.most_common():
                if _n <= 1:
                    _doc_text = f'{_n} Document has'
                else:
                    _doc_text = f'{_n} Documents have'
                if len(_a) == 1:
                    _text = f'{_doc_text} one attribute'
                elif len(_a) == 0:
                    _text = f'{_doc_text} no attribute'
                else:
                    _text = f'{_doc_text} attributes'
                table.add_row(_text, str(_a))

        console = Console()
        all_attrs_names = tuple(sorted(all_attrs_names))
        if not all_attrs_names:
            console.print(table)
            return

        attr_table = Table(box=box.SIMPLE, title='Attributes Summary')
        attr_table.add_column('Attribute')
        attr_table.add_column('Data type')
        attr_table.add_column('#Unique values')
        attr_table.add_column('Has empty value')

        all_attrs_values = self.get_attributes(*all_attrs_names)
        if len(all_attrs_names) == 1:
            # with a single field name the values come back un-nested
            all_attrs_values = [all_attrs_values]
        for _a, _a_name in zip(all_attrs_values, all_attrs_names):
            try:
                _a = set(_a)
            except Exception:
                # best-effort: some field values (e.g. arrays) are not hashable, in
                # which case all values are reported instead of the distinct ones
                pass
            _set_type_a = {type(_aa).__name__ for _aa in _a}
            attr_table.add_row(
                _a_name,
                str(tuple(_set_type_a)),
                str(len(_a)),
                str(any(_aa is None for _aa in _a)),
            )
        console.print(table, attr_table)

    def plot_embeddings(
        self,
        title: str = 'MyDocumentArray',
        path: Optional[str] = None,
        image_sprites: bool = False,
        min_image_size: int = 16,
        channel_axis: int = -1,
        start_server: bool = True,
        port: Optional[int] = None,
    ) -> str:
        """Interactively visualize :attr:`.embeddings` using the Embedding Projector.

        :param title: the title of this visualization. If you want to compare multiple embeddings at the same time,
            make sure to give different names each time and set ``path`` to the same value.
        :param path: if set, then append the visualization to an existing folder, where you can compare multiple
            embeddings at the same time. Make sure to use a different ``title`` each time .
        :param image_sprites: if set, visualize the dots using :attr:`.uri` and :attr:`.blob`.
        :param min_image_size: only used when `image_sprites=True`. the minimum size of the image
        :param channel_axis: only used when `image_sprites=True`. the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        :param start_server: if set, start a HTTP server and open the frontend directly. Otherwise, you need to rely on ``return`` path and serve by yourself.
        :param port: if set, run the embedding-projector frontend at given port. Otherwise a random port is used.
        :return: the path to the embeddings visualization info.
        """
        import gzip
        import shutil

        from ...helper import random_port, __resources_path__

        path = path or tempfile.mkdtemp()
        emb_fn = f'{title}.tsv'
        meta_fn = f'{title}.metas.tsv'
        config_fn = 'config.json'
        sprite_fn = f'{title}.png'

        if image_sprites:
            img_per_row = ceil(sqrt(len(self)))
            # the projector cannot handle sprite sheets larger than 8192px
            canvas_size = min(img_per_row * min_image_size, 8192)
            # use exactly the layout plot_image_sprites will use, so the warning
            # below fires precisely when documents would be dropped from the sprite
            sprite_per_row, img_size = self._sprite_layout(
                len(self), canvas_size, min_image_size
            )
            max_docs = sprite_per_row ** 2
            if len(self) > max_docs:
                warnings.warn(
                    f'''
{self!r} has more than {max_docs} elements, which is the maximum number of image sprites can support.
The resulting visualization may not be correct. You can do the following:
- use fewer images: `da[:10000].plot_embeddings()`
- reduce the `min_image_size` to a smaller number, say 8 or 4 (but bear in mind you can hardly recognize anything with a 4x4 image)
- turn off `image_sprites` via `da.plot_embeddings(image_sprites=False)`
'''
                )
            self.plot_image_sprites(
                os.path.join(path, sprite_fn),
                canvas_size=canvas_size,
                min_size=min_image_size,
                channel_axis=channel_axis,
            )

        self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')

        _exclude_fields = ('embedding', 'blob', 'scores')
        # only the first Document is inspected: with at most one remaining field the
        # metadata file is written without a header row
        with_header = (
            len(set(self[0].non_empty_fields).difference(set(_exclude_fields))) > 1
        )
        self.save_csv(
            os.path.join(path, meta_fn),
            exclude_fields=_exclude_fields,
            dialect='excel-tab',
            with_header=with_header,
        )

        _epj_config = {
            'embeddings': [
                {
                    'tensorName': title,
                    'tensorShape': list(self.embeddings.shape),
                    'tensorPath': f'/static/{emb_fn}',
                    'metadataPath': f'/static/{meta_fn}',
                    'sprite': {
                        'imagePath': f'/static/{sprite_fn}',
                        'singleImageDim': (img_size,) * 2,
                    }
                    if image_sprites
                    else {},
                }
            ]
        }

        config_path = os.path.join(path, config_fn)
        if os.path.exists(config_path):
            # append to an existing visualization folder
            with open(config_path, encoding='utf-8') as fp:
                old_config = json.load(fp)
            _epj_config['embeddings'].extend(old_config.get('embeddings', []))
        with open(config_path, 'w', encoding='utf-8') as fp:
            json.dump(_epj_config, fp)

        # copy the bundled frontend byte-for-byte; a text-mode round trip would use the
        # platform default encoding, which can fail or corrupt non-ASCII content
        with gzip.open(
            os.path.join(__resources_path__, 'embedding-projector/index.html.gz'), 'rb'
        ) as fr, open(os.path.join(path, 'index.html'), 'wb') as fp:
            shutil.copyfileobj(fr, fp)

        if start_server:

            def _get_fastapi_app():
                from fastapi import FastAPI
                from starlette.middleware.cors import CORSMiddleware
                from starlette.staticfiles import StaticFiles

                app = FastAPI()
                app.add_middleware(
                    CORSMiddleware,
                    allow_origins=['*'],
                    allow_credentials=True,
                    allow_methods=['*'],
                    allow_headers=['*'],
                )
                app.mount('/static', StaticFiles(directory=path), name='static')
                return app

            import uvicorn

            app = _get_fastapi_app()
            port = port or random_port()
            t_m = threading.Thread(
                target=uvicorn.run,
                kwargs=dict(app=app, port=port, log_level='error'),
                daemon=True,
            )
            t_m.start()
            url_html_path = (
                f'http://localhost:{port}/static/index.html?config={config_fn}'
            )
            try:
                import webbrowser

                webbrowser.open(url_html_path, new=2)
            except Exception:
                pass  # intentional pass, browser support isn't cross-platform
            finally:
                print(
                    f'You should see a webpage opened in your browser, '
                    f'if not, you may open {url_html_path} manually'
                )
            # keeps the caller blocked while the server thread runs
            t_m.join()
        return path

    def plot_image_sprites(
        self,
        output: Optional[str] = None,
        canvas_size: int = 512,
        min_size: int = 16,
        channel_axis: int = -1,
    ) -> None:
        """Generate a sprite image for all image blobs in this DocumentArray-like object.

        An image sprite is a collection of images put into a single image. It is always square-sized.
        Each sub-image is also square-sized and equally-sized. Documents that do not
        fit on the canvas are skipped.

        :param output: Optional path to store the visualization. If not given, show in UI
        :param canvas_size: the size of the canvas
        :param min_size: the minimum size of the image
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis.
            Documents whose blob is loaded from ``uri`` are always treated as channel-last.
        :raises ValueError: if this object is empty.
        """
        if not self:
            raise ValueError(f'{self!r} is empty')

        img_per_row, img_size = self._sprite_layout(len(self), canvas_size, min_size)
        max_num_img = img_per_row ** 2
        sprite_img = np.zeros(
            [img_size * img_per_row, img_size * img_per_row, 3], dtype='uint8'
        )

        img_id = 0
        for d in self:
            # work on a copy so the caller's Documents are not modified
            _d = copy.deepcopy(d)
            # per-Document source axis: a URI-backed Document must not override the
            # caller's ``channel_axis`` for the Documents that follow it
            src_axis = channel_axis
            if _d.content_type != 'blob':
                _d.load_uri_to_image_blob()
                src_axis = -1  # blobs loaded from ``uri`` are treated as channel-last
            _d.set_image_blob_channel_axis(src_axis, -1).set_image_blob_shape(
                shape=(img_size, img_size)
            )
            row_id, col_id = divmod(img_id, img_per_row)
            sprite_img[
                (row_id * img_size) : ((row_id + 1) * img_size),
                (col_id * img_size) : ((col_id + 1) * img_size),
            ] = _d.blob
            img_id += 1
            if img_id >= max_num_img:
                break

        from PIL import Image

        im = Image.fromarray(sprite_img)
        if output:
            with open(output, 'wb') as fp:
                im.save(fp)
        else:
            # only needed for on-screen display
            import matplotlib.pyplot as plt

            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.show()

    @staticmethod
    def _sprite_layout(num_images: int, canvas_size: int, min_size: int) -> tuple:
        """Compute the square grid used to tile ``num_images`` square images.

        Shared by :meth:`plot_image_sprites` and :meth:`plot_embeddings` so the
        capacity warning and the generated sprite always agree.

        :param num_images: number of images to place (must be positive).
        :param canvas_size: requested side length of the canvas, in pixels.
        :param min_size: minimum side length of a single tile, in pixels.
        :return: ``(img_per_row, img_size)``; the grid holds ``img_per_row ** 2`` tiles.
        """
        img_per_row = ceil(sqrt(num_images))
        img_size = int(canvas_size / img_per_row)
        if img_size < min_size:
            # tiles would be too small: fix the tile size and fit as many as possible
            img_size = min_size
            # keep at least one tile so a canvas smaller than ``min_size`` does not
            # yield a zero-sized sprite
            img_per_row = max(int(canvas_size / img_size), 1)
        return img_per_row, img_size