forked from python/mypy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path dataclasses.py
More file actions
388 lines (337 loc) · 16.1 KB
/
dataclasses.py
File metadata and controls
388 lines (337 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""Plugin that provides support for dataclasses."""
from typing import Dict, List, Set, Tuple, Optional
from typing_extensions import Final
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, FuncDef, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method, _get_decorator_bool_argument
from mypy.types import Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
from mypy.server.trigger import make_wildcard_trigger
# Decorator names (bare and fully qualified) whose application turns a class
# into a dataclass.
dataclass_makers = {'dataclass', 'dataclasses.dataclass'}  # type: Final

# Name of the type variable registered on each dataclass and used as the
# self type of the generated comparison methods.
SELF_TVAR_NAME = '_DT'  # type: Final
class DataclassAttribute:
    """One dataclass field as seen by the plugin.

    Besides the field name, this records whether the field is an ``__init__``
    parameter, whether it was declared as ``InitVar[...]``, whether it has a
    default, and where it was declared. ``serialize``/``deserialize`` convert
    it to a plain dict and back; ``DataclassTransformer`` stores those dicts
    under ``TypeInfo.metadata['dataclass']``.
    """

    def __init__(
            self,
            name: str,
            is_in_init: bool,
            is_init_var: bool,
            has_default: bool,
            line: int,
            column: int,
    ) -> None:
        # Identity of the field.
        self.name = name
        # Flags describing how the field takes part in the generated __init__.
        self.is_in_init = is_in_init
        self.is_init_var = is_init_var
        self.has_default = has_default
        # Source position of the declaration; used as the error context when
        # attribute ordering is invalid.
        self.line = line
        self.column = column

    def to_argument(self, info: TypeInfo) -> Argument:
        """Build the ``__init__`` parameter corresponding to this field.

        The parameter is optional exactly when the field has a default; its
        annotation is the type currently recorded for the field in ``info``.
        """
        variable = self.to_var(info)
        annotation = info[self.name].type
        if self.has_default:
            kind = ARG_OPT
        else:
            kind = ARG_POS
        return Argument(
            variable=variable,
            type_annotation=annotation,
            initializer=None,
            kind=kind,
        )

    def to_var(self, info: TypeInfo) -> Var:
        """Create a fresh ``Var`` named after this field, typed as in ``info``."""
        field_type = info[self.name].type
        return Var(self.name, field_type)

    def serialize(self) -> JsonDict:
        """Dump this attribute to a JSON-compatible dict (inverse of ``deserialize``)."""
        return dict(
            name=self.name,
            is_in_init=self.is_in_init,
            is_init_var=self.is_init_var,
            has_default=self.has_default,
            line=self.line,
            column=self.column,
        )

    @classmethod
    def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'DataclassAttribute':
        """Rebuild an attribute from a dict produced by ``serialize``.

        ``info`` is not used here.
        """
        attribute = cls(**data)
        return attribute
class DataclassTransformer:
    """Applies dataclass semantics to one decorated class definition.

    The class being processed is ``self._ctx.cls``. ``transform`` synthesizes
    the special methods, marks frozen fields as properties, strips InitVars
    from the class namespace and records the collected fields in
    ``info.metadata['dataclass']``. Dataclass subclasses read those fields
    back from that metadata in ``collect_attributes``.
    """

    def __init__(self, ctx: ClassDefContext) -> None:
        # Plugin hook context: the class definition plus the semantic-analyzer
        # API used below to report errors, defer, and look up named types.
        self._ctx = ctx

    def transform(self) -> None:
        """Apply all the necessary transformations to the underlying
        dataclass so as to ensure it is fully type checked according
        to the rules in PEP 557.

        Returns early, without modifying the class, when the attributes
        cannot be collected yet or when any attribute's symbol has no type
        yet. In the latter case ``ctx.api.defer()`` is called here.
        """
        ctx = self._ctx
        info = self._ctx.cls.info
        attributes = self.collect_attributes()
        if attributes is None:
            # Some definitions are not ready, defer() should be already called.
            return
        for attr in attributes:
            node = info.get(attr.name)
            if node is None:
                # Nodes of superclass InitVars not used in __init__ cannot be reached.
                assert attr.is_init_var and not attr.is_in_init
                continue
            if node.type is None:
                # The generated methods need every attribute's type; try again
                # in a later semantic-analysis pass.
                ctx.api.defer()
                return
        # Decorator keyword arguments, with the defaults given here when absent.
        decorator_arguments = {
            'init': _get_decorator_bool_argument(self._ctx, 'init', True),
            'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
            'order': _get_decorator_bool_argument(self._ctx, 'order', False),
            'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False),
        }

        # If there are no attributes, it may be that the semantic analyzer has not
        # processed them yet. In order to work around this, we can simply skip generating
        # __init__ if there are no attributes, because if the user truly did not define any,
        # then the object default __init__ with an empty signature will be present anyway.
        # A user-written __init__ is never replaced; a previously plugin-generated one is.
        if (decorator_arguments['init'] and
                ('__init__' not in info.names or info.names['__init__'].plugin_generated) and
                attributes):
            add_method(
                ctx,
                '__init__',
                args=[attr.to_argument(info) for attr in attributes if attr.is_in_init],
                return_type=NoneType(),
            )

        if (decorator_arguments['eq'] and info.get('__eq__') is None or
                decorator_arguments['order']):
            # Type variable for self types in generated methods.
            # It is registered in the class namespace under SELF_TVAR_NAME.
            obj_type = ctx.api.named_type('__builtins__.object')
            self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                         [], obj_type)
            info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

        # Add an eq method, but only if the class doesn't already have one.
        if decorator_arguments['eq'] and info.get('__eq__') is None:
            for method_name in ['__eq__', '__ne__']:
                # The TVar is used to enforce that "other" must have
                # the same type as self (covariant). Note the
                # "self_type" parameter to add_method.
                obj_type = ctx.api.named_type('__builtins__.object')
                cmp_tvar_def = TypeVarDef(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                          -1, [], obj_type)
                cmp_other_type = TypeVarType(cmp_tvar_def)
                cmp_return_type = ctx.api.named_type('__builtins__.bool')

                add_method(
                    ctx,
                    method_name,
                    args=[Argument(Var('other', cmp_other_type), cmp_other_type, None, ARG_POS)],
                    return_type=cmp_return_type,
                    self_type=cmp_other_type,
                    tvar_def=cmp_tvar_def,
                )

        # Add <, >, <=, >=, but only if the class has an eq method.
        if decorator_arguments['order']:
            if not decorator_arguments['eq']:
                # Reported, but the ordering methods are still generated below.
                ctx.api.fail('eq must be True if order is True', ctx.cls)

            for method_name in ['__lt__', '__gt__', '__le__', '__ge__']:
                # Like for __eq__ and __ne__, we want "other" to match
                # the self type.
                obj_type = ctx.api.named_type('__builtins__.object')
                order_tvar_def = TypeVarDef(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                            -1, [], obj_type)
                order_other_type = TypeVarType(order_tvar_def)
                order_return_type = ctx.api.named_type('__builtins__.bool')
                order_args = [
                    Argument(Var('other', order_other_type), order_other_type, None, ARG_POS)
                ]

                # A user-defined ordering method conflicts with order=True;
                # report it at the method's definition.
                existing_method = info.get(method_name)
                if existing_method is not None and not existing_method.plugin_generated:
                    assert existing_method.node
                    ctx.api.fail(
                        'You may not have a custom %s method when order=True' % method_name,
                        existing_method.node,
                    )

                add_method(
                    ctx,
                    method_name,
                    args=order_args,
                    return_type=order_return_type,
                    self_type=order_other_type,
                    tvar_def=order_tvar_def,
                )

        if decorator_arguments['frozen']:
            self._freeze(attributes)

        self.reset_init_only_vars(info, attributes)

        # Persist the collected fields so that dataclass subclasses can
        # inherit them (read back in collect_attributes).
        info.metadata['dataclass'] = {
            'attributes': [attr.serialize() for attr in attributes],
            'frozen': decorator_arguments['frozen'],
        }

    def reset_init_only_vars(self, info: TypeInfo, attributes: List[DataclassAttribute]) -> None:
        """Remove init-only vars from the class and reset init var declarations.

        For every InitVar attribute, its entry is deleted from ``info.names``
        (if present). The ``node`` of each matching annotated assignment in
        the class body is then cleared.
        """
        for attr in attributes:
            if attr.is_init_var:
                if attr.name in info.names:
                    del info.names[attr.name]
                else:
                    # Nodes of superclass InitVars not used in __init__ cannot be reached.
                    assert attr.is_init_var and not attr.is_in_init
                for stmt in info.defn.defs.body:
                    if isinstance(stmt, AssignmentStmt) and stmt.unanalyzed_type:
                        lvalue = stmt.lvalues[0]
                        if isinstance(lvalue, NameExpr) and lvalue.name == attr.name:
                            # Reset node so that another semantic analysis pass will
                            # recreate a symbol node for this attribute.
                            lvalue.node = None

    def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
        """Collect all attributes declared in the dataclass and its parents.

        All assignments of the form

          a: SomeType
          b: SomeOtherType = ...

        are collected.

        Returns None if a declaration in the class body is still a
        placeholder, meaning it is not analyzed yet. Otherwise returns the
        attributes of dataclass ancestors followed by this class's own.
        Attributes redeclared in a subclass keep the position of their first
        declaration. As a side effect, an error is reported for any
        ``__init__`` attribute without a default that follows one with a
        default.
        """
        # First, collect attributes belonging to the current class.
        ctx = self._ctx
        cls = self._ctx.cls
        attrs = []  # type: List[DataclassAttribute]
        known_attrs = set()  # type: Set[str]
        for stmt in cls.defs.body:
            # Any assignment that doesn't use the new type declaration
            # syntax can be ignored out of hand.
            if not (isinstance(stmt, AssignmentStmt) and stmt.new_syntax):
                continue

            # a: int, b: str = 1, 'foo' is not supported syntax so we
            # don't have to worry about it.
            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr):
                continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:
                # This name is likely blocked by a star import. We don't need to defer because
                # defer() is already called by mark_incomplete().
                continue

            node = sym.node
            if isinstance(node, PlaceholderNode):
                # This node is not ready yet.
                return None
            assert isinstance(node, Var)

            # x: ClassVar[int] is ignored by dataclasses.
            if node.is_classvar:
                continue

            # x: InitVar[int] is turned into x: int and is removed from the class.
            is_init_var = False
            node_type = get_proper_type(node.type)
            if (isinstance(node_type, Instance) and
                    node_type.type.fullname() == 'dataclasses.InitVar'):
                is_init_var = True
                node.type = node_type.args[0]

            has_field_call, field_args = _collect_field_args(stmt.rvalue)

            # field(init=...) controls whether the attribute is an __init__ parameter.
            is_in_init_param = field_args.get('init')
            if is_in_init_param is None:
                is_in_init = True
            else:
                is_in_init = bool(ctx.api.parse_bool(is_in_init_param))

            has_default = False
            # Ensure that something like x: int = field() is rejected
            # after an attribute with a default.
            if has_field_call:
                has_default = 'default' in field_args or 'default_factory' in field_args

            # All other assignments are already type checked.
            elif not isinstance(stmt.rvalue, TempNode):
                has_default = True

            known_attrs.add(lhs.name)
            attrs.append(DataclassAttribute(
                name=lhs.name,
                is_in_init=is_in_init,
                is_init_var=is_init_var,
                has_default=has_default,
                line=stmt.line,
                column=stmt.column,
            ))

        # Next, collect attributes belonging to any class in the MRO
        # as long as those attributes weren't already collected. This
        # makes it possible to overwrite attributes in subclasses.
        # copy() because we potentially modify all_attrs below and if this code requires debugging
        # we'll have unmodified attrs laying around.
        all_attrs = attrs.copy()
        init_method = cls.info.get_method('__init__')
        # mro[1:-1] skips the class itself and the final MRO entry.
        for info in cls.info.mro[1:-1]:
            if 'dataclass' not in info.metadata:
                continue

            super_attrs = []  # type: List[DataclassAttribute]
            # Each class depends on the set of attributes in its dataclass ancestors.
            ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname()))

            for data in info.metadata['dataclass']['attributes']:
                name = data['name']  # type: str
                if name not in known_attrs:
                    attr = DataclassAttribute.deserialize(info, data)
                    if attr.is_init_var and isinstance(init_method, FuncDef):
                        # InitVars are removed from classes so, in order for them to be inherited
                        # properly, we need to re-inject them into subclasses' sym tables here.
                        # To do that, we look 'em up from the parents' __init__. These variables
                        # are subsequently removed from the sym table at the end of
                        # DataclassTransformer.transform.
                        for arg, arg_name in zip(init_method.arguments, init_method.arg_names):
                            if arg_name == attr.name:
                                cls.info.names[attr.name] = SymbolTableNode(MDEF, arg.variable)

                    known_attrs.add(name)
                    super_attrs.append(attr)
                elif all_attrs:
                    # How early in the attribute list an attribute appears is determined by the
                    # reverse MRO, not simply MRO.
                    # See https://docs.python.org/3/library/dataclasses.html#inheritance for
                    # details.
                    for attr in all_attrs:
                        if attr.name == name:
                            # Safe to mutate all_attrs: we break right after.
                            all_attrs.remove(attr)
                            super_attrs.append(attr)
                            break
            # Ancestors' attributes go before those collected so far.
            all_attrs = super_attrs + all_attrs

        # Ensure that arguments without a default don't follow
        # arguments that have a default.
        found_default = False
        for attr in all_attrs:
            # If we find any attribute that is_in_init but that
            # doesn't have a default after one that does have one,
            # then that's an error.
            if found_default and attr.is_in_init and not attr.has_default:
                # If the issue comes from merging different classes, report it
                # at the class definition point.
                context = (Context(line=attr.line, column=attr.column) if attr in attrs
                           else ctx.cls)
                ctx.api.fail(
                    'Attributes without a default cannot follow attributes with one',
                    context,
                )

            found_default = found_default or (attr.has_default and attr.is_in_init)

        return all_attrs

    def _freeze(self, attributes: List[DataclassAttribute]) -> None:
        """Converts all attributes to @property methods in order to
        emulate frozen classes.

        Each attribute's ``Var`` in the class symbol table is marked
        ``is_property``. Attributes with no entry in this class's own table
        (presumably ones inherited from a dataclass parent) get a freshly
        created property ``Var`` added to it.
        """
        info = self._ctx.cls.info
        for attr in attributes:
            sym_node = info.names.get(attr.name)
            if sym_node is not None:
                var = sym_node.node
                assert isinstance(var, Var)
                var.is_property = True
            else:
                var = attr.to_var(info)
                var.info = info
                var.is_property = True
                var._fullname = info.fullname() + '.' + var.name()
                info.names[var.name()] = SymbolTableNode(MDEF, var)
def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
    """Class-definition hook: apply dataclass semantics to the class in ``ctx``.

    All of the work is delegated to a fresh ``DataclassTransformer``.
    """
    DataclassTransformer(ctx).transform()
def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
    """Inspect the right-hand side of a dataclass attribute declaration.

    Returns a tuple whose first item tells whether ``expr`` is a call to
    ``dataclasses.field`` and whose second item maps each keyword-argument
    name passed to that call to its argument expression. For any other
    expression, ``(False, {})`` is returned.

    ``field()`` only accepts keyword arguments, but user code may still pass
    arguments that carry no name, such as ``field(**options)``. Previously an
    ``assert`` aborted the whole plugin on such input. Unnamed arguments are
    now skipped, and any keyword arguments that are visible are still
    reported.
    """
    if (
            isinstance(expr, CallExpr) and
            isinstance(expr.callee, RefExpr) and
            expr.callee.fullname == 'dataclasses.field'
    ):
        args = {}  # type: Dict[str, Expression]
        for name, arg in zip(expr.arg_names, expr.args):
            if name is None:
                # No static keyword name (positional or unpacked argument);
                # nothing reliable to record for it.
                continue
            args[name] = arg
        return True, args
    return False, {}