-
Notifications
You must be signed in to change notification settings - Fork 152
Expand file tree
/
Copy pathgenerated_dataclass_patch.py
More file actions
79 lines (61 loc) · 2.21 KB
/
generated_dataclass_patch.py
File metadata and controls
79 lines (61 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dataclasses import replace
from codegen.generated_dataclass import (
GeneratedDataclass,
GeneratedField,
GeneratedType,
)
def reorder_required_fields(models: dict[str, GeneratedDataclass]):
    """Move required fields ahead of optional ones in every model, in place.

    A field counts as required when ``_is_required`` reports it (no default and
    no default factory). Within each group the original relative order is
    kept. Required-first ordering is needed so keyword arguments work in the
    generated constructor; presumably this reflects the dataclass rule that a
    field without a default may not follow one with a default.

    Models whose ``fields`` is empty are left untouched. Other entries of
    ``models`` are replaced with new objects built by ``dataclasses.replace``;
    the original model objects themselves are not mutated.
    """
    for key, current in models.items():
        if not current.fields:
            # Nothing to reorder; keep the very same model object.
            continue
        # sorted() is stable, and False (required) sorts before True
        # (optional), so this is a stable "required first" partition.
        reordered = sorted(current.fields, key=lambda f: not _is_required(f))
        # Overwriting an existing key does not resize the dict, so this is
        # safe while iterating over models.items().
        models[key] = replace(current, fields=reordered)
def quote_recursive_references(models: dict[str, GeneratedDataclass]):
    """Break known reference cycles by quoting type names, in place.

    When two generated dataclasses refer to each other, one side has to use a
    string (forward-reference) annotation, for example::

        class Foo:
            bar: Optional[Bar]

        class Bar:
            foo: "Foo"

    The only cycle handled here is ``jobs.ForEachTask``: its references to
    ``Task`` and ``TaskParam`` are quoted. If ``jobs.ForEachTask`` is not in
    ``models``, nothing changes.
    """
    # See also _append_resolve_recursive_imports (defined elsewhere); the two
    # presumably have to stay in agreement about which names are recursive.
    target = "jobs.ForEachTask"
    if target not in models:
        return
    models[target] = _quote_recursive_references_for_model(
        models[target],
        references={"Task", "TaskParam"},
    )
def _quote_recursive_references_for_model(
    model: GeneratedDataclass,
    references: set[str],
) -> GeneratedDataclass:
    """Return a copy of ``model`` in which the given type names become strings.

    For every field, both ``type_name`` and ``param_type_name`` are walked.
    Any ``GeneratedType`` whose ``name`` is in ``references`` is rebuilt with
    that name wrapped in double quotes (``Task`` -> ``"Task"``), so that it is
    emitted as a forward-reference annotation. Type parameters are searched
    recursively, but the parameters of a type that is itself quoted are left
    as they are.

    The input is not mutated: ``dataclasses.replace`` builds new objects. A
    type that is neither referenced nor parameterized is returned unchanged.

    Args:
        model: Generated dataclass whose field types should be rewritten.
        references: Bare (unquoted) type names to quote.

    Returns:
        A new ``GeneratedDataclass`` with the rewritten fields.
    """

    def quote_type(generated_type: GeneratedType):
        # Referenced name: quote it and stop descending.
        if generated_type.name in references:
            return replace(generated_type, name=f'"{generated_type.name}"')
        # Leaf type with nothing to quote: reuse the original object.
        if not generated_type.parameters:
            return generated_type
        quoted_params = [quote_type(p) for p in generated_type.parameters]
        return replace(generated_type, parameters=quoted_params)

    def quote_field(field: GeneratedField):
        # NOTE(review): param_type_name is passed through unconditionally, so
        # this assumes it is never None — confirm against GeneratedField.
        return replace(
            field,
            type_name=quote_type(field.type_name),
            param_type_name=quote_type(field.param_type_name),
        )

    new_fields = list(map(quote_field, model.fields))
    return replace(model, fields=new_fields)
def _is_required(field: GeneratedField) -> bool:
    """Tell whether ``field`` must be supplied by the caller.

    A field is required when it has neither a ``default`` nor a
    ``default_factory`` (both are ``None``). ``default_factory`` is only
    inspected when ``default`` is ``None``.
    """
    # NOTE(review): this treats a ``default`` of None as "no default"; if
    # GeneratedField can carry a literal None default this needs revisiting.
    if field.default is not None:
        return False
    return field.default_factory is None