-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathplot.py
More file actions
525 lines (437 loc) · 20 KB
/
plot.py
File metadata and controls
525 lines (437 loc) · 20 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
import copy
import json
import os.path
import tempfile
import threading
import time
import warnings
from collections import Counter
from math import sqrt, ceil, floor
from typing import Optional, Tuple, List
import numpy as np
from docarray.helper import _get_array_info
class PlotMixin:
    """Helper functions for plotting the arrays."""

    @staticmethod
    def _classify_shell(shell_repr: str) -> str:
        """Map the ``str()`` of an IPython shell object to a frontend kind.

        :param shell_repr: string representation of the object returned by ``get_ipython()``
        :return: ``'jupyter'`` when the text mentions ``ZMQInteractiveShell``, ``'colab'`` when it
            mentions ``google.colab``, and ``'local'`` for everything else
        """
        if 'ZMQInteractiveShell' in shell_repr:
            return 'jupyter'
        if 'google.colab' in shell_repr:
            return 'colab'
        # Any other shell (e.g. a terminal IPython session) is handled like a plain
        # interpreter, so the caller always ends up in one of its three branches.
        return 'local'

    @staticmethod
    def _sprite_layout(
        num_docs: int, canvas_size: int, min_size: int
    ) -> Tuple[int, int, int]:
        """Compute the grid used to lay out ``num_docs`` square tiles on a sprite canvas.

        :param num_docs: number of tiles to place, must be positive
        :param canvas_size: the requested width of the canvas in pixels
        :param min_size: the minimum edge length of a single tile in pixels
        :return: a tuple ``(img_size, img_per_row, img_per_col)``
        """
        img_per_row = ceil(sqrt(num_docs))
        img_size = int(canvas_size / img_per_row)
        if img_size < min_size:
            # tiles would be too small: pin the tile size and fit fewer tiles per row
            img_size = min_size
            img_per_row = int(canvas_size / img_size) or 1
        img_per_col = ceil(num_docs / img_per_row)
        return img_size, img_per_row, img_per_col

    @staticmethod
    def _prepare_image_doc(doc, image_source: str, channel_axis: int, skip_empty: bool):
        """Return a deep copy of ``doc`` whose ``.tensor`` has the color channel moved to the last axis.

        The given Document is never modified. ``channel_axis`` is resolved per Document, so a
        Document loaded from its URI does not change how the following Documents are treated.

        :param doc: the Document to prepare
        :param image_source: ``uri`` or ``tensor``; with ``tensor`` the URI is loaded when the
            Document content type is not ``tensor``
        :param channel_axis: the axis id of the color channel in ``doc.tensor``
        :param skip_empty: return ``None`` instead of raising when the Document has neither
            ``uri`` nor ``tensor``
        :return: the prepared copy, or ``None`` when the Document is skipped
        :raises ValueError: on an empty Document with ``skip_empty=False`` or on an unknown ``image_source``
        """
        if not doc.uri and doc.tensor is None:
            if skip_empty:
                return None
            raise ValueError(
                'Document has neither `uri` nor `tensor`, can not be plotted'
            )

        _d = copy.deepcopy(doc)
        if image_source == 'uri' or (
            image_source == 'tensor' and _d.content_type != 'tensor'
        ):
            _d.load_uri_to_image_tensor()
            # a tensor freshly loaded from the URI is treated as channel-last
            channel_axis = -1
        elif image_source not in ('uri', 'tensor'):
            raise ValueError('image_source can be only `uri` or `tensor`')

        _d.set_image_tensor_channel_axis(channel_axis, -1)
        return _d

    def _ipython_display_(self):
        """Displays the object in IPython as a side effect"""
        self.summary()

    def summary(self):
        """Print the structure and attribute summary of this DocumentArray object.

        Up to three panels are printed: a documents summary, an attributes summary (only when
        there is at least one attribute name) and the storage config (only when
        ``_get_storage_infos()`` returns something truthy).

        .. warning::
            Calling {meth}`.summary` on large DocumentArray can be slow.
        """
        from rich.table import Table
        from rich.console import Console
        from rich.panel import Panel
        import rich.markup
        from rich import box

        tables = []
        console = Console()

        (
            is_homo,
            _nested_in,
            _nested_items,
            attr_counter,
            all_attrs_names,
        ) = _get_array_info(self)

        table = Table(box=box.SIMPLE, highlight=True)
        table.show_header = False
        table.add_row('Type', self.__class__.__name__)
        table.add_row('Length', str(len(self)))
        table.add_row('Homogenous Documents', str(is_homo))
        if _nested_in:
            table.add_row('Has nested Documents in', str(tuple(_nested_in)))

        if is_homo:
            # first key of the counter
            table.add_row('Common Attributes', str(next(iter(attr_counter))))

        for item in _nested_items:
            table.add_row(item['name'], item['value'])

        is_multimodal = all(d.is_multimodal for d in self)
        table.add_row('Multimodal dataclass', str(is_multimodal))

        if getattr(self, '_subindices', None):
            table.add_row(
                'Subindices', rich.markup.escape(str(tuple(self._subindices.keys())))
            )

        tables.append(Panel(table, title='Documents Summary', expand=False))

        all_attrs_names = tuple(sorted(all_attrs_names))
        if all_attrs_names:
            attr_table = Table(box=box.SIMPLE, highlight=True)
            attr_table.add_column('Attribute')
            attr_table.add_column('Data type')
            attr_table.add_column('#Unique values')
            attr_table.add_column('Has empty value')

            for _a_name in all_attrs_names:
                # Rebuilt for every attribute so a failure below can never leave the
                # values of the previous attribute in place.
                _a = [getattr(d, _a_name) for d in self]
                try:
                    _a = set(_a)
                except TypeError:
                    pass  # intentional ignore as some fields are not hashable
                _set_type_a = {type(_aa).__name__ for _aa in _a}
                attr_table.add_row(
                    _a_name,
                    str(tuple(_set_type_a)),
                    str(len(_a)),
                    str(any(_aa is None for _aa in _a)),
                )
            tables.append(Panel(attr_table, title='Attributes Summary', expand=False))

        storage_infos = self._get_storage_infos()
        if storage_infos:
            storage_table = Table(box=box.SIMPLE, highlight=True)
            storage_table.show_header = False
            for k, v in storage_infos.items():
                storage_table.add_row(k, v)

            tables.append(
                Panel(
                    storage_table,
                    title=f'[bold]{self.__class__.__name__}[/bold] Config',
                    expand=False,
                )
            )

        console.print(*tables)

    def plot_embeddings(
        self,
        title: str = 'MyDocumentArray',
        path: Optional[str] = None,
        image_sprites: bool = False,
        min_image_size: int = 16,
        channel_axis: int = -1,
        start_server: bool = True,
        host: str = '127.0.0.1',
        port: Optional[int] = None,
        image_source: str = 'tensor',
        exclude_fields_metas: Optional[List[str]] = None,
    ) -> str:
        """Interactively visualize :attr:`.embeddings` using the Embedding Projector and store the visualization informations.

        The embeddings, the metadata, an optional image sprite, ``config.json`` and ``index.html``
        are written into ``path``. With ``start_server=True`` the folder is then served from a
        background thread and this call blocks on that thread until it is interrupted.

        :param title: the title of this visualization. If you want to compare multiple embeddings at the same time,
            make sure to give different names each time and set ``path`` to the same value.
        :param host: if set, bind the embedding-projector frontend to given host. Otherwise `localhost` is used.
        :param port: if set, run the embedding-projector frontend at given port. Otherwise a random port is used.
        :param image_sprites: if set, visualize the dots using :attr:`.uri` and :attr:`.tensor`.
        :param path: if set, then append the visualization to an existing folder, where you can compare multiple
            embeddings at the same time. Make sure to use a different ``title`` each time. If not set, a new
            temporary folder is created.
        :param min_image_size: only used when `image_sprites=True`. the minimum size of the image
        :param channel_axis: only used when `image_sprites=True`. the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        :param start_server: if set, start a HTTP server and open the frontend directly. Otherwise, you need to rely on ``return`` path and serve by yourself.
        :param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
        :param exclude_fields_metas: specify the fields that you want to exclude from metadata tsv file
        :return: the path to the embeddings visualization info.
        """
        from docarray.helper import random_port, __resources_path__

        path = path or tempfile.mkdtemp()
        emb_fn = f'{title}.tsv'
        meta_fn = f'{title}.metas.tsv'
        config_fn = 'config.json'
        sprite_fn = f'{title}.png'

        if image_sprites:
            img_per_row = ceil(sqrt(len(self)))
            # the canvas is capped at 8192 pixels
            canvas_size = min(img_per_row * min_image_size, 8192)
            img_size = max(int(canvas_size / img_per_row), min_image_size)
            max_docs = ceil(canvas_size / img_size) ** 2
            if len(self) > max_docs:
                warnings.warn(
                    f'\n{self!r} has more than {max_docs} elements, which is the maximum number of image sprites can support.\n'
                    'The resulting visualization may not be correct. You can do the following:\n'
                    '- use fewer images: `da[:10000].plot_embeddings()`\n'
                    '- reduce the `min_image_size` to a smaller number, say 8 or 4 (but bear in mind you can hardly recognize anything with a 4x4 image)\n'
                    '- turn off `image_sprites` via `da.plot_embeddings(image_sprites=False)`\n'
                )

            self.plot_image_sprites(
                os.path.join(path, sprite_fn),
                canvas_size=canvas_size,
                min_size=min_image_size,
                channel_axis=channel_axis,
                image_source=image_source,
            )

        self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')

        _exclude_fields_metas = ['embedding', 'tensor', 'scores'] + (
            exclude_fields_metas or []
        )

        # the header row is dropped when at most one metadata field of the first
        # Document survives the exclusion list
        remaining_fields = set(self[0].non_empty_fields).difference(
            _exclude_fields_metas
        )
        with_header = len(remaining_fields) > 1

        self.save_csv(
            os.path.join(path, meta_fn),
            exclude_fields=_exclude_fields_metas,
            dialect='excel-tab',
            with_header=with_header,
        )

        _epj_config = {
            'embeddings': [
                {
                    'tensorName': title,
                    'tensorShape': list(self.embeddings.shape),
                    'tensorPath': f'/static/{emb_fn}',
                    'metadataPath': f'/static/{meta_fn}',
                    'sprite': {
                        'imagePath': f'/static/{sprite_fn}',
                        'singleImageDim': (img_size,) * 2,
                    }
                    if image_sprites
                    else {},
                }
            ]
        }

        config_path = os.path.join(path, config_fn)
        if os.path.exists(config_path):
            # keep the visualizations previously stored in the same folder
            with open(config_path) as fp:
                old_config = json.load(fp)
            _epj_config['embeddings'].extend(old_config.get('embeddings', []))

        with open(config_path, 'w') as fp:
            json.dump(_epj_config, fp)

        import gzip

        with gzip.open(
            os.path.join(__resources_path__, 'embedding-projector/index.html.gz'), 'rt'
        ) as fr, open(os.path.join(path, 'index.html'), 'w') as fp:
            fp.write(fr.read())

        if start_server:

            def _get_fastapi_app():
                from fastapi import FastAPI
                from starlette.middleware.cors import CORSMiddleware
                from starlette.staticfiles import StaticFiles

                app = FastAPI()
                app.add_middleware(
                    CORSMiddleware,
                    allow_origins=['*'],
                    allow_credentials=True,
                    allow_methods=['*'],
                    allow_headers=['*'],
                )
                app.mount('/static', StaticFiles(directory=path), name='static')
                return app

            import uvicorn

            app = _get_fastapi_app()

            port = port or random_port()
            t_m = threading.Thread(
                target=uvicorn.run,
                kwargs=dict(app=app, host=host, port=port, log_level='error'),
                daemon=True,
            )
            url_html_path = f'http://{host}:{port}/static/index.html?config={config_fn}'
            t_m.start()

            try:
                _env = self._classify_shell(str(get_ipython()))  # noqa
            except NameError:
                # `get_ipython` is not defined, i.e. we are not inside IPython
                _env = 'local'

            if _env == 'jupyter':
                time.sleep(
                    1
                )  # jitter is required otherwise encounter weird `strict-origin-when-cross-origin` error in browser
                from IPython.display import IFrame, display  # noqa

                display(IFrame(src=url_html_path, width="100%", height=600))
                warnings.warn(
                    f'Showing iframe in cell, you may want to open {url_html_path} in a new tab for better experience. '
                    f'Also, `localhost` may need to be changed to the IP address if your jupyter is running remotely. '
                    f'Click "stop" button in the toolbar to move to the next cell.'
                )
            elif _env == 'colab':
                from google.colab.output import eval_js  # noqa

                colab_url = eval_js(f'google.colab.kernel.proxyPort({port})')
                colab_url += f'/static/index.html?config={config_fn}'
                warnings.warn(
                    f'Showing iframe in cell, you may want to open {colab_url} in a new tab for better experience. '
                    f'Click "stop" button in the toolbar to move to the next cell.'
                )
                time.sleep(
                    1
                )  # jitter is required otherwise encounter weird `strict-origin-when-cross-origin` error in browser
                from IPython.display import IFrame, display

                display(IFrame(src=colab_url, width="100%", height=600))
            else:
                try:
                    import webbrowser

                    webbrowser.open(url_html_path, new=2)
                except Exception:
                    pass  # intentional pass, browser support isn't cross-platform
                finally:
                    print(
                        f'You should see a webpage opened in your browser, '
                        f'if not, you may open {url_html_path} manually'
                    )
            # keep serving until the user interrupts
            t_m.join()
        return path

    def save_gif(
        self,
        output: str,
        channel_axis: int = -1,
        duration: int = 200,
        size_ratio: float = 1.0,
        inline_display: bool = False,
        image_source: str = 'tensor',
        skip_empty: bool = False,
        show_index: bool = False,
        show_progress: bool = False,
    ) -> None:
        """
        Save a gif of the DocumentArray. Each frame corresponds to a Document.uri/.tensor in the DocumentArray.

        :param output: the file path to save the gif to.
        :param channel_axis: the color channel axis of the tensor.
        :param duration: the duration of each frame in milliseconds.
        :param size_ratio: the size ratio of each frame. Frames are only resized when the ratio is below 1.
        :param inline_display: if to show the gif in Jupyter notebook.
        :param image_source: the source of the image in Document attribute, can be ``uri`` or ``tensor``.
        :param skip_empty: if to skip empty documents.
        :param show_index: if to show the index of the document in the top-left corner.
        :param show_progress: if to show a progress bar.
        :raises ValueError: if a Document is empty and ``skip_empty`` is not set, if ``image_source``
            is invalid, or if there is no frame to save at all.
        :return:
        """
        from rich.progress import track
        from PIL import Image, ImageDraw

        def img_iterator():
            for _idx, d in enumerate(
                track(self, description='Plotting', disable=not show_progress)
            ):
                _d = self._prepare_image_doc(d, image_source, channel_axis, skip_empty)
                if _d is None:
                    continue

                if size_ratio < 1:
                    img_size_h, img_size_w, _ = _d.tensor.shape
                    _d.set_image_tensor_shape(
                        shape=(
                            int(size_ratio * img_size_h),
                            int(size_ratio * img_size_w),
                        )
                    )

                if show_index:
                    _img = Image.fromarray(_d.tensor)
                    draw = ImageDraw.Draw(_img)
                    draw.text((0, 0), str(_idx), (255, 255, 255))
                    _d.tensor = np.asarray(_img)

                yield Image.fromarray(_d.tensor).convert('RGB')

        imgs = img_iterator()
        # the first frame drives the save, the remaining frames are appended lazily
        img = next(imgs, None)
        if img is None:
            raise ValueError(
                f'{self!r} has no Document with `uri` or `tensor`, there is nothing to save'
            )

        with open(output, 'wb') as fp:
            img.save(
                fp=fp,
                format='GIF',
                append_images=imgs,
                save_all=True,
                duration=duration,
                loop=0,
            )

        if inline_display:
            # aliased so the PIL `Image` used by the frame generator is never rebound
            from IPython.display import Image as IPythonImage, display

            display(IPythonImage(output))

    def plot_image_sprites(
        self,
        output: Optional[str] = None,
        canvas_size: int = 512,
        min_size: int = 16,
        channel_axis: int = -1,
        image_source: str = 'tensor',
        skip_empty: bool = False,
        show_progress: bool = False,
        show_index: bool = False,
        fig_size: Optional[Tuple[int, int]] = (10, 10),
        keep_aspect_ratio: bool = False,
    ) -> None:
        """Generate a sprite image for all image tensors in this DocumentArray-like object.

        An image sprite is a collection of images put into a single image.
        Each sub-image is equally-sized; it is square-sized unless ``keep_aspect_ratio`` is set.
        Every Document keeps the grid cell given by its position in the array, so a skipped
        Document leaves its cell black.

        :param output: Optional path to store the visualization. If not given, show in UI
        :param canvas_size: the size of the canvas
        :param min_size: the minimum size of the image
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        :param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
        :param skip_empty: skip Document who has no .uri or .tensor.
        :param show_index: show the index on the top-left corner of every image
        :param fig_size: the size of the figure
        :param show_progress: show a progressbar while plotting.
        :param keep_aspect_ratio: preserve the aspect ratio of the image by using the aspect ratio of the first image in self.
        :raises ValueError: if the array is empty, or if any Document can not be turned into a tile
        """
        if not self:
            raise ValueError(f'{self!r} is empty')

        from rich.progress import track
        from PIL import Image, ImageDraw

        img_size, img_per_row, img_per_col = self._sprite_layout(
            len(self), canvas_size, min_size
        )

        sprite_img = np.zeros(
            [img_size * img_per_col, img_size * img_per_row, 3], dtype='uint8'
        )
        img_size_w, img_size_h = img_size, img_size
        set_aspect_ratio = False

        try:
            for _idx, d in enumerate(
                track(self, description='Plotting', disable=not show_progress)
            ):
                _d = self._prepare_image_doc(d, image_source, channel_axis, skip_empty)
                if _d is None:
                    continue

                if keep_aspect_ratio and not set_aspect_ratio:
                    # the first plotted image fixes the tile height of the whole
                    # sprite; nothing has been drawn yet so the canvas can be replaced
                    h, w, _ = _d.tensor.shape
                    img_size_h = int(h * img_size / w)
                    sprite_img = np.zeros(
                        [img_size_h * img_per_col, img_size_w * img_per_row, 3],
                        dtype='uint8',
                    )
                    set_aspect_ratio = True

                _d.set_image_tensor_shape(shape=(img_size_h, img_size_w))

                row_id, col_id = divmod(_idx, img_per_row)

                if show_index:
                    _img = Image.fromarray(np.asarray(_d.tensor, dtype='uint8'))
                    draw = ImageDraw.Draw(_img)
                    draw.text((0, 0), str(_idx), (255, 255, 255))
                    _d.tensor = np.asarray(_img)

                sprite_img[
                    (row_id * img_size_h) : ((row_id + 1) * img_size_h),
                    (col_id * img_size_w) : ((col_id + 1) * img_size_w),
                ] = _d.tensor

        except Exception as ex:
            raise ValueError(
                'Bad image tensor. Try different `image_source` or `channel_axis`'
            ) from ex

        im = Image.fromarray(sprite_img)

        if output:
            with open(output, 'wb') as fp:
                im.save(fp)
        else:
            # matplotlib is only needed to show the sprite, not to write it to a file
            import matplotlib.pyplot as plt

            plt.figure(figsize=fig_size, frameon=False)
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.show()