-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathdocument_array_summary.py
More file actions
76 lines (63 loc) · 2.77 KB
/
document_array_summary.py
File metadata and controls
76 lines (63 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import TYPE_CHECKING, List
from docarray.typing.tensor.abstract_tensor import AbstractTensor
if TYPE_CHECKING:
from docarray.array import DocVec
from docarray.array.any_array import AnyDocArray
class DocArraySummary:
    """Print a human-readable, `rich`-formatted summary of a document array."""

    def __init__(self, docs: 'AnyDocArray'):
        """
        Store the array that `summary()` will describe.

        :param docs: the document array to summarize.
        """
        self.docs = docs

    def summary(self) -> None:
        """
        Print a summary of this DocList object and a summary of the schema of its
        Document type.

        The printed panel always shows the class name and the length of the array.
        If the array is a `DocVec`, each stacked column whose value is an
        `AbstractTensor` is listed as well; stacked columns holding anything else
        are skipped. Afterwards the schema summary of the array's `doc_type` is
        printed.
        """
        # NOTE(review): these imports are local to the method; presumably to keep
        # `rich` optional and to avoid a circular import with `docarray.array`
        # -- confirm before moving them to module level.
        from rich import box
        from rich.console import Console
        from rich.panel import Panel
        from rich.table import Table

        from docarray.array import DocVec

        table = Table(box=box.SIMPLE, highlight=True)
        table.show_header = False
        table.add_row('Type', self.docs.__class__.__name__)
        table.add_row('Length', str(len(self.docs)), end_section=True)

        if isinstance(self.docs, DocVec):
            table.add_row('Stacked columns:')
            for field_name in self._get_stacked_fields(docs=self.docs):
                # Follow the dotted path ('attr.nested_attr') down to the column.
                val = self.docs
                for attr in field_name.split('.'):
                    val = getattr(val, attr)

                if isinstance(val, AbstractTensor):
                    table.add_row(f' • {field_name}:', self._describe_tensor(val))

        Console().print(Panel(table, title='DocList Summary', expand=False))
        self.docs.doc_type.schema_summary()

    @staticmethod
    def _describe_tensor(val: 'AbstractTensor') -> str:
        """
        Build the one-line description of a stacked tensor column.

        :param val: the tensor column to describe.
        :return: `'None (<class name>)'` if every entry of `val` is NaN, otherwise
            the class name, shape and dtype of `val`, followed by its device when
            the computational backend reports a truthy one.
        """
        comp_be = val.get_comp_backend()
        type_name = val.__class__.__name__

        # An all-NaN column is reported as None; presumably NaN is the placeholder
        # for a tensor column that was never set -- confirm against DocVec.
        if comp_be.to_numpy(comp_be.isnan(val)).all():
            return f'None ({type_name})'

        description = (
            f'{type_name} of shape {comp_be.shape(val)}'
            f', dtype: {comp_be.dtype(val)}'
        )
        # Query the device once and reuse it for both the check and the text.
        device = comp_be.device(val)
        if device:
            description += f', device: {device}'
        return description

    @staticmethod
    def _get_stacked_fields(docs: 'DocVec') -> List[str]:
        """
        Return a list of the field names of a DocVec instance that are
        doc_vec, i.e. all the fields that are of type AbstractTensor. Nested field
        paths are separated by dot, such as: 'attr.nested_attr'.

        The names of `docs._storage.tensor_columns` come first, followed by the
        recursively collected names of every non-None entry of
        `docs._storage.doc_columns`, each prefixed with that entry's field name.

        :param docs: the DocVec whose column storage is inspected.
        :return: the (possibly dotted) names of all stacked tensor fields.
        """
        # TODO: this might be broken (unverified note kept from the original code).
        storage = docs._storage

        # Only the column names are needed, so iterate the keys directly.
        fields = list(storage.tensor_columns)

        for field_name, value_doc in storage.doc_columns.items():
            if value_doc is None:
                # Nested document column that is not set: nothing to descend into.
                continue
            fields.extend(
                f'{field_name}.{nested}'
                for nested in DocArraySummary._get_stacked_fields(docs=value_doc)
            )

        return fields