-
Notifications
You must be signed in to change notification settings - Fork 237
Expand file tree
/
Copy pathbase.py
More file actions
145 lines (116 loc) · 4.67 KB
/
base.py
File metadata and controls
145 lines (116 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import copy as cp
import dataclasses
from dataclasses import fields
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Tuple, Dict
from docarray.dataclasses import is_multimodal
from docarray.helper import typename
if TYPE_CHECKING: # pragma: no cover
from docarray.typing import T
@lru_cache()
def _get_fields(dc):
return [f.name for f in fields(dc)]
class BaseDCType:
    """Base wrapper whose whole state is one backing data object in ``self._data``.

    This class only provides construction and a handful of generic helpers.
    It relies on several members that are NOT defined here and must be
    supplied by subclasses or mixins:

    - ``_data_class``: callable invoked as ``_data_class(self, **kwargs)`` to
      build the backing data object;
    - ``_unresolved_fields_dest``: name of the attribute (a dict-like object
      supporting ``update``) that receives keyword arguments matching no field;
    - ``_post_init_fields``: iterable of names re-assigned through ``setattr``
      on the wrapper once the data object exists;
    - ``_from_dataclass``: alternate constructor used for multimodal inputs;
    - ``to_bytes``: serializer used by ``__bytes__`` and ``nbytes``.
    """

    # Factory for the backing data object; expected to be overridden.
    _data_class = None

    def __init__(
        self: 'T',
        _obj: Optional['T'] = None,
        copy: bool = False,
        field_resolver: Optional[Dict[str, str]] = None,
        unknown_fields_handler: str = 'catch',
        **kwargs,
    ):
        """Build the object from another instance, a dict, a multimodal object or kwargs.

        :param _obj: source object. An instance of the same type shares (or,
            with ``copy=True``, deep-copies) its data; a ``dict`` is merged
            into ``kwargs``; an object for which ``is_multimodal`` is true is
            converted through ``_from_dataclass``.
        :param copy: deep-copy the data of ``_obj`` instead of sharing it.
            Only consulted when ``_obj`` is an instance of the same type.
        :param field_resolver: mapping from incoming keyword name to field
            name. It is applied only after the first construction attempt
            failed with ``TypeError``.
        :param unknown_fields_handler: what to do with keywords that match no
            field: ``'raise'`` raises ``AttributeError``; ``'catch'`` stores
            them in the attribute named by ``_unresolved_fields_dest``; any
            other value discards them.
        :param kwargs: field values.
        :raises AttributeError: on construction failure when
            ``unknown_fields_handler`` is ``'raise'``.
        :raises ValueError: if no data object could be built from the inputs.
        """
        self._data = None
        if isinstance(_obj, type(self)):
            if copy:
                self.copy_from(_obj)
            else:
                # Share the very same data object with ``_obj``.
                self._data = _obj._data
        elif isinstance(_obj, dict):
            kwargs.update(_obj)
        elif is_multimodal(_obj):
            self._data = type(self)._from_dataclass(_obj)._data

        if kwargs:
            try:
                if self._data is not None:
                    # Data already exists: merge the unresolved-fields
                    # payload into the existing container rather than
                    # letting ``replace`` overwrite it.
                    if self._unresolved_fields_dest in kwargs:
                        getattr(self, self._unresolved_fields_dest).update(
                            kwargs[self._unresolved_fields_dest]
                        )
                        kwargs.pop(self._unresolved_fields_dest)
                    self._data = dataclasses.replace(self._data, **kwargs)
                else:
                    self._data = self._data_class(self, **kwargs)
            except TypeError as ex:
                # A TypeError here is treated as "some keyword is not a field".
                if unknown_fields_handler == 'raise':
                    raise AttributeError('unknown attributes') from ex

                if field_resolver:
                    kwargs = {
                        field_resolver.get(k, k): v for k, v in kwargs.items()
                    }

                _fields = _get_fields(self._data_class)
                _unknown_kwargs = None
                _unresolved = set(kwargs).difference(_fields)

                if _unresolved:
                    # Split unknown keywords off so the retry only sees
                    # genuine field names.
                    _unknown_kwargs = {k: kwargs[k] for k in _unresolved}
                    for k in _unresolved:
                        kwargs.pop(k)

                if self._data is not None:
                    self._data = dataclasses.replace(self._data, **kwargs)
                else:
                    self._data = self._data_class(self, **kwargs)

                if _unknown_kwargs and unknown_fields_handler == 'catch':
                    getattr(self, self._unresolved_fields_dest).update(
                        _unknown_kwargs
                    )

            # Route these fields through the wrapper's own attribute
            # assignment after the data object has been created.
            for k in self._post_init_fields:
                if k in kwargs:
                    setattr(self, k, kwargs[k])

        if not _obj and not kwargs and self._data is None:
            # Nothing usable was supplied: fall back to an empty data object.
            self._data = self._data_class(self)

        if self._data is None:
            raise ValueError(
                f'Failed to initialize {typename(self)} from obj={_obj}, kwargs={kwargs}'
            )

    def copy_from(self: 'T', other: 'T') -> None:
        """Overwrite self by deep-copying the data of another :class:`Document`.

        :param other: the other Document to copy from
        """
        self._data = cp.deepcopy(other._data)

    def clear(self) -> None:
        """Set every currently non-empty field of this :class:`Document` to ``None``."""
        for f in self.non_empty_fields:
            setattr(self._data, f, None)

    def pop(self, *fields) -> None:
        """Set the given fields of this :class:`Document` to ``None``.

        Names that are not attributes of this object are silently ignored.

        :param fields: field names to clear.
        """
        for f in fields:
            if hasattr(self, f):
                setattr(self._data, f, None)

    @property
    def non_empty_fields(self) -> Tuple[str, ...]:
        """Get all non-empty fields of this :class:`Document`.

        The value is delegated to ``_non_empty_fields`` of the backing data
        object.

        :return: field names in a tuple.
        """
        return self._data._non_empty_fields

    @property
    def nbytes(self) -> int:
        """Return the length of this object's ``bytes()`` serialization.

        :return: number of bytes
        """
        return len(bytes(self))

    def __hash__(self):
        """Hash by the backing data object."""
        return hash(self._data)

    def __repr__(self):
        """Show the class name, the non-empty fields and the ``id`` attribute.

        Falls back to the Python object id when there is no ``id`` attribute.
        """
        content = str(self.non_empty_fields)
        content += f' at {getattr(self, "id", id(self))}'
        return f'<{self.__class__.__name__} {content.strip()}>'

    def __bytes__(self):
        """Serialize through ``to_bytes`` (provided outside this class)."""
        return self.to_bytes()

    def __eq__(self, other):
        """Compare by backing data; only objects of the exact same type can be equal.

        For any other type ``NotImplemented`` is returned so that Python can
        try the reflected comparison on ``other`` before falling back to
        ``False``.
        """
        if type(self) is type(other):
            return self._data == other._data
        return NotImplemented