-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathtensor_display.py
More file actions
89 lines (74 loc) · 3.37 KB
/
tensor_display.py
File metadata and controls
89 lines (74 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Licensed to the LF AI & Data foundation under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing_extensions import TYPE_CHECKING
if TYPE_CHECKING:
from rich.console import Console, ConsoleOptions, RenderResult
from rich.measure import Measurement
from docarray.typing.tensor.abstract_tensor import AbstractTensor
class TensorDisplay:
    """
    Rich representation of a tensor.

    When the tensor, after being detached and squeezed by its computational
    backend, is one-dimensional with fewer than ``tensor_max_plot_len`` elements,
    it is drawn as a strip of colored cells (one cell per element, brightness
    driven by the min-max normalized value). Any other tensor is rendered as a
    one-line text summary with its class name, shape and dtype.
    """

    # Minimum width (in cells) reported to rich by ``__rich_measure__``.
    tensor_min_width: int = 30
    # 1-D tensors with at least this many elements are summarized as text
    # instead of being drawn as a color strip.
    tensor_max_plot_len: int = 200

    def __init__(self, tensor: 'AbstractTensor'):
        self.tensor = tensor

    def _squeezed(self):
        """
        Detach and squeeze the wrapped tensor using its computational backend.

        :return: a ``(comp_backend, squeezed_tensor)`` pair
        """
        comp_be = self.tensor.get_comp_backend()
        return comp_be, comp_be.squeeze(comp_be.detach(self.tensor))

    @staticmethod
    def _ends_row(idx: int, n_cells: int, width: int) -> bool:
        """
        Decide whether a line break should follow cell ``idx`` of the color strip.

        A break is emitted after every ``width`` cells, so each row holds exactly
        ``width`` cells, but never after the final cell (no trailing empty line).

        :param idx: zero-based index of the cell that was just emitted
        :param n_cells: total number of cells in the strip
        :param width: number of cells per row; must be >= 1
        :return: True if a line break should be yielded after this cell
        """
        n_done = idx + 1
        return n_done % width == 0 and n_done < n_cells

    def __rich_console__(
        self, console: 'Console', options: 'ConsoleOptions'
    ) -> 'RenderResult':
        """
        Render the tensor for rich.

        :param console: the rich console doing the rendering (unused here)
        :param options: rich render options; ``options.max_width`` bounds the
            number of cells per row of the color strip
        :return: a generator of rich segments / renderables
        """
        comp_be, t_squeezed = self._squeezed()

        if (
            comp_be.n_dim(t_squeezed) == 1
            and comp_be.shape(t_squeezed)[0] < self.tensor_max_plot_len
        ):
            import colorsys

            from rich.color import Color
            from rich.segment import Segment
            from rich.style import Style

            n_cells = comp_be.shape(t_squeezed)[0]
            # Clamp to >= 1 so the row-wrapping modulo can never divide by zero.
            width = max(options.max_width, 1)
            # Assumed to min-max rescale values into [0, 5]; `y / 5` below then
            # maps each value onto [0, 1] before choosing a luminance.
            tensor_normalized = comp_be.minmax_normalize(t_squeezed, (0, 5))

            hue = 0.75
            saturation = 1.0
            for idx, y in enumerate(tensor_normalized):
                luminance = 0.1 + ((y / 5) * 0.7)
                r, g, b = colorsys.hls_to_rgb(hue, luminance + 0.07, saturation)
                color = Color.from_rgb(r * 255, g * 255, b * 255)
                # Identical fg/bg color so the half-block glyph paints the full cell.
                yield Segment('▄', Style(color=color, bgcolor=color))
                if self._ends_row(idx, n_cells, width):
                    yield Segment.line()
        else:
            from rich.text import Text

            yield Text(
                f'{self.tensor.__class__.__name__} of '
                f'shape {comp_be.shape(self.tensor)}, '
                f'dtype: {str(comp_be.dtype(self.tensor))}'
            )

    def __rich_measure__(
        self, console: 'Console', options: 'ConsoleOptions'
    ) -> 'Measurement':
        """
        Report the width range rich should reserve for this renderable.

        :param console: the rich console (unused here)
        :param options: rich render options; ``options.max_width`` caps the width
        :return: a ``Measurement`` with minimum 1 and maximum given by
            ``_compute_table_width``
        """
        from rich.measure import Measurement

        width = self._compute_table_width(max_width=options.max_width)
        return Measurement(1, width)

    def _compute_table_width(self, max_width: int) -> int:
        """
        Compute the width of the table.

        For a tensor that is 1-D after squeezing and shorter than ``max_width``,
        the width is its length raised to at least ``tensor_min_width`` and then
        capped at ``max_width``. Any other tensor uses the full ``max_width``.

        :param max_width: the maximum width available for the table
        :return: the width of the table
        """
        comp_be, t_squeezed = self._squeezed()
        if comp_be.n_dim(t_squeezed) == 1 and comp_be.shape(t_squeezed)[0] < max_width:
            length = comp_be.shape(t_squeezed)[0]
            return min(max(length, self.tensor_min_width), max_width)
        return max_width