Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections import Counter
from math import sqrt, ceil, floor
from typing import Optional
from typing import Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -317,6 +317,8 @@ def plot_image_sprites(
image_source: str = 'tensor',
skip_empty: bool = False,
show_progress: bool = False,
show_index: bool = False,
fig_size: Optional[Tuple[int, int]] = None,
) -> None:
"""Generate a sprite image for all image tensors in this DocumentArray-like object.

Expand Down Expand Up @@ -351,9 +353,12 @@ def plot_image_sprites(
img_id = 0

from rich.progress import track
from PIL import Image, ImageDraw

try:
for d in track(self, description='Plotting', disable=not show_progress):
for _idx, d in enumerate(
track(self, description='Plotting', disable=not show_progress)
):

if not d.uri and d.tensor is None:
if skip_empty:
Expand All @@ -379,6 +384,13 @@ def plot_image_sprites(

row_id = floor(img_id / img_per_row)
col_id = img_id % img_per_row

if show_index:
_img = Image.fromarray(_d.tensor)
draw = ImageDraw.Draw(_img)
draw.text((0, 0), str(_idx), (255, 255, 255))
_d.tensor = np.asarray(_img)

sprite_img[
(row_id * img_size) : ((row_id + 1) * img_size),
(col_id * img_size) : ((col_id + 1) * img_size),
Expand All @@ -392,14 +404,14 @@ def plot_image_sprites(
'Bad image tensor. Try different `image_source` or `channel_axis`'
) from ex

from PIL import Image

im = Image.fromarray(sprite_img)

if output:
with open(output, 'wb') as fp:
im.save(fp)
else:
if fig_size:
plt.figure(figsize=fig_size, frameon=False)
plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
Expand Down
2 changes: 0 additions & 2 deletions docarray/document/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def display(self):
else:
self.summary()

plot = deprecate_by(display, removed_at='0.5')

def plot_matches_sprites(
self,
top_k: int = 10,
Expand Down