forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
133 lines (104 loc) · 4.02 KB
/
base.py
File metadata and controls
133 lines (104 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import copy as cp
from dataclasses import fields
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Tuple, Dict
from .dataclasses import is_multimodal
from .helper import typename
if TYPE_CHECKING:
from .typing import T
@lru_cache()
def _get_fields(dc):
    """Return the field names of the dataclass ``dc`` as a list.

    The result is memoised per ``dc`` by ``lru_cache``, so ``dc`` has to be
    hashable and repeated calls hand back the very same list object.

    :param dc: a dataclass (or dataclass instance) accepted by
        :func:`dataclasses.fields`.
    :return: the ``name`` of every field, in the order ``fields`` yields them.
    """
    names = []
    for field in fields(dc):
        names.append(field.name)
    return names
class BaseDCType:
    """Base wrapper that keeps all of its field values in a ``_data`` object.

    The methods below rely on attributes that this base class does not define
    itself and that subclasses are expected to provide: ``_data_class``
    (instantiated as ``_data_class(self, **fields)``), ``_post_init_fields``,
    ``_unresolved_fields_dest``, ``_from_dataclass`` and ``to_bytes``.
    """

    # Factory of the backing data object; ``None`` here, set by subclasses.
    _data_class = None

    def __init__(
        self: 'T',
        _obj: Optional['T'] = None,
        copy: bool = False,
        field_resolver: Optional[Dict[str, str]] = None,
        unknown_fields_handler: str = 'catch',
        **kwargs,
    ):
        """Build the object from another instance, a dict, a multimodal
        dataclass or plain keyword arguments.

        :param _obj: the source to initialise from. An instance of the same
            type has its data shared (or deep-copied when ``copy`` is true);
            a ``dict`` is merged into ``kwargs`` and wins over keyword
            arguments of the same name; an object accepted by
            ``is_multimodal`` is converted through ``_from_dataclass``.
        :param copy: deep-copy the data of ``_obj`` instead of sharing it.
            Only looked at when ``_obj`` is an instance of this type.
        :param field_resolver: mapping from incoming field names to the names
            used by the data class. It is only consulted after the data class
            rejected the raw ``kwargs`` and ``unknown_fields_handler`` is not
            ``'raise'``.
        :param unknown_fields_handler: what to do when the data class rejects
            ``kwargs``: ``'raise'`` raises :class:`AttributeError`;
            ``'catch'`` stores the unknown fields in the attribute named by
            ``_unresolved_fields_dest``; any other value drops them.
        :param kwargs: field values handed to the data class.
        :raises AttributeError: if the data class rejects ``kwargs`` with a
            ``TypeError`` and ``unknown_fields_handler`` is ``'raise'``.
        :raises ValueError: if no data object could be built from the inputs.
        """
        self._data = None
        if isinstance(_obj, type(self)):
            if copy:
                self.copy_from(_obj)
            else:
                # Share the backing data object with ``_obj`` (no copy).
                self._data = _obj._data
        elif isinstance(_obj, dict):
            kwargs.update(_obj)
        elif is_multimodal(_obj):
            self._data = type(self)._from_dataclass(_obj)._data

        if kwargs:
            try:
                # Fast path: every key is accepted by the data class as is.
                self._data = self._data_class(self, **kwargs)
            except TypeError as ex:
                if unknown_fields_handler == 'raise':
                    # Surface the underlying reason so that the offending
                    # attribute is visible in the message itself.
                    raise AttributeError(f'unknown attributes: {ex}') from ex

                if field_resolver:
                    kwargs = {
                        field_resolver.get(k, k): v for k, v in kwargs.items()
                    }

                _fields = _get_fields(self._data_class)
                _unresolved = set(kwargs).difference(_fields)
                # Move everything the data class does not know out of kwargs.
                _unknown_kwargs = {k: kwargs.pop(k) for k in _unresolved}

                self._data = self._data_class(self, **kwargs)

                if _unknown_kwargs and unknown_fields_handler == 'catch':
                    getattr(self, self._unresolved_fields_dest).update(
                        _unknown_kwargs
                    )

            # These fields are (re-)assigned through ``setattr`` on ``self``
            # once the data object exists.
            for k in self._post_init_fields:
                if k in kwargs:
                    setattr(self, k, kwargs[k])

        if not _obj and not kwargs and self._data is None:
            # Nothing usable was passed in: fall back to an empty data object.
            self._data = self._data_class(self)

        if self._data is None:
            raise ValueError(
                f'Failed to initialize {typename(self)} from obj={_obj}, kwargs={kwargs}'
            )

    def copy_from(self: 'T', other: 'T') -> None:
        """Overwrite self with a deep copy of the data of another object.

        :param other: the other object to copy from
        """
        self._data = cp.deepcopy(other._data)

    def clear(self) -> None:
        """Reset every non-empty field by assigning ``None`` to it on the
        backing data object."""
        for f in self.non_empty_fields:
            setattr(self._data, f, None)

    def pop(self, *fields) -> None:
        """Reset the given fields by assigning ``None`` to them on the backing
        data object.

        Names that this object does not expose as an attribute are skipped.

        :param fields: field names to clear.
        """
        for f in fields:
            if hasattr(self, f):
                setattr(self._data, f, None)

    @property
    def non_empty_fields(self) -> Tuple[str, ...]:
        """Get all non-empty fields of this object, as reported by the
        ``_non_empty_fields`` attribute of the backing data object.

        :return: field names in a tuple.
        """
        return self._data._non_empty_fields

    @property
    def nbytes(self) -> int:
        """Return the size of the serialised form of this object.

        :return: number of bytes produced by ``bytes(self)``
        """
        return len(bytes(self))

    def __hash__(self):
        # Hash and equality are both delegated to the backing data object.
        return hash(self._data)

    def __repr__(self):
        content = str(self.non_empty_fields)
        # Fall back to the Python object id when there is no ``id`` attribute.
        content += f' at {getattr(self, "id", id(self))}'
        return f'<{self.__class__.__name__} {content.strip()}>'

    def __bytes__(self):
        return self.to_bytes()

    def __eq__(self, other):
        # Only objects of exactly the same type can be equal.
        if type(self) is type(other):
            return self._data == other._data
        return False