-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinner_emitter.py
More file actions
142 lines (126 loc) · 4.84 KB
/
Copy pathinner_emitter.py
File metadata and controls
142 lines (126 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import Any, Dict, List, Tuple
from onnx import AttributeProto
from .annotations import ELEMENT_TYPE_NAME
from .emitter import BaseEmitter
from .translate import Translater
class InnerEmitter(BaseEmitter):
    """
    Converts translation events into Python source code that rebuilds the
    ONNX model with the ``onnx.helper`` API (``make_node``, ``make_graph``,
    ``make_model``, ...).

    Every ``_emit_*`` method returns a list of source lines. The generated
    code uses names it does not import itself (``np``, ``from_array``,
    ``make_opsetid``, ``make_node``, ``make_graph``, ``make_model``,
    ``make_tensor_value_info``, ``TensorProto``); presumably the caller
    prepends the matching imports -- TODO confirm.
    """

    def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
        """
        Renders an attribute value into a string.

        A graph attribute cannot be written inline: its graph is translated
        recursively with this emitter and wrapped into a local function
        ``_make_local_graph_<attribute name>``. The function's result is
        bound to a variable named after the attribute, and that variable
        name is returned as the value. Every other attribute type is
        delegated to the base class.

        :param value: value to convert, ``value[0]`` is the
            :class:`onnx.AttributeProto`
        :return: rows to append before, actual value
        """
        att = value[0]
        if att.type != AttributeProto.GRAPH:
            return super().render_attribute_value(value)

        tr = Translater(att.g, emitter=self)
        rows = tr.export(as_str=False, single_line=False)
        func_name = f"_make_local_graph_{att.name}"
        new_rows = [f"def {func_name}():"]
        for line in rows:
            # Only the graph is needed: stop before the model is assembled.
            # Match the exact line written by _emit_to_onnx rather than any
            # row merely containing the substring "make_model".
            if line.startswith("model = make_model("):
                break
            new_rows.append("    " + line)
        # _emit_end_graph binds the translated graph to the name ``graph``.
        new_rows.append("    return graph")
        new_rows.append(f"{att.name} = {func_name}()")
        return new_rows, att.name

    def join(self, rows: List[str], single_line: bool = False) -> str:
        "Returns the rows joined with newlines. `single_line` is unused."
        return "\n".join(rows)

    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits the ``opset_imports`` list, one ``make_opsetid(domain, version)``
        per item of ``kwargs["opsets"]`` (empty when missing).
        """
        opsets = kwargs.get("opsets", {})
        return [
            "opset_imports = [",
            *(f"    make_opsetid({k!r}, {v!r})," for k, v in opsets.items()),
            "]",
        ]

    def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits the ``make_model`` call assembling ``graph``, ``functions`` and
        ``opset_imports``. ``kwargs`` is ignored.
        """
        return [
            "model = make_model(",
            "    graph,",
            "    functions=functions,",
            "    opset_imports=opset_imports",
            ")",
        ]

    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Declares the empty lists the following events append to.
        ``kwargs`` is ignored.
        """
        return [
            "inputs = []",
            "outputs = []",
            "nodes = []",
            "initializers = []",
            "sparse_initializers = []",
            "functions = []",
        ]

    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits the ``make_graph`` call binding the result to ``graph``;
        the graph name is ``kwargs["name"]`` or ``"noname"``.
        """
        name = kwargs.get("name", "noname")
        return [
            "graph = make_graph(",
            "    nodes,",
            f"    {name!r},",
            "    inputs,",
            "    outputs,",
            "    initializers,",
            "    sparse_initializer=sparse_initializers,",
            ")",
        ]

    def _array_to_source(self, value: Any) -> str:
        """
        Returns the Python literal of ``value.tolist()``.

        Every leaf is written with ``repr`` (so 0-d string arrays stay
        quoted) and non-finite floats are written as ``np.nan``, ``np.inf``
        or ``-np.inf``: ``repr`` would give the bare names ``nan``/``inf``,
        which are undefined in the generated code.
        """
        import math

        def _render(obj: Any) -> str:
            if isinstance(obj, list):
                return "[" + ", ".join(_render(o) for o in obj) + "]"
            if isinstance(obj, float) and not math.isfinite(obj):
                if math.isnan(obj):
                    return "np.nan"
                return "np.inf" if obj > 0 else "-np.inf"
            return repr(obj)

        return _render(value.tolist())

    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits an ``initializers.append(from_array(np.array(...)))`` statement
        for the array ``kwargs["value"]`` named ``kwargs["name"]``.
        """
        name = kwargs["name"]
        value = kwargs["value"]
        # The generated code spells these dtypes with numpy's trailing
        # underscore aliases (np.bool_, np.object_, np.str_).
        repl = {"bool": "bool_", "object": "object_", "str": "str_"}
        sdtype = repl.get(str(value.dtype), str(value.dtype))
        return [
            "initializers.append(",
            "    from_array(",
            f"        np.array({self._array_to_source(value)}, dtype=np.{sdtype}),",
            f"        name={name!r}",
            "    )",
            ")",
        ]

    def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits ``<container>.append(make_tensor_value_info(...))``.

        :param container: name of the generated list, ``"inputs"`` or
            ``"outputs"``
        :param kwargs: ``name`` (required), optional ``elem_type`` (key of
            ``ELEMENT_TYPE_NAME``) and ``shape``
        :return: a single source line; a falsy ``elem_type`` gives
            ``TensorProto.UNDEFINED``
        """
        name = kwargs["name"]
        elem_type = kwargs.get("elem_type", None)
        shape = kwargs.get("shape", None)
        if not elem_type:
            return [
                f"{container}.append(make_tensor_value_info({name!r}, TensorProto.UNDEFINED, []))"
            ]
        type_name = ELEMENT_TYPE_NAME[elem_type]
        # NOTE(review): an empty or missing shape is written as ``shape=[]``,
        # i.e. a rank-0 tensor, not an unknown shape -- confirm intended.
        shape_text = repr(shape) if shape else "[]"
        return [
            f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{type_name}, shape={shape_text}))"
        ]

    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
        "Emits a graph input appended to ``inputs``."
        return self._emit_io("inputs", **kwargs)

    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
        "Emits a graph output appended to ``outputs``."
        return self._emit_io("outputs", **kwargs)

    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
        """
        Emits one ``nodes.append(make_node(...))`` statement.

        ``kwargs`` holds ``op_type``, ``inputs``, ``outputs`` and optionally
        ``domain`` (written only when non-empty) and ``atts`` (attribute
        name -> value, rendered by :meth:`render_attribute_value`).
        Rows that attributes need defined beforehand (local sub-graph
        builders) are returned before the node statement.
        """
        op_type = kwargs["op_type"]
        inputs = kwargs["inputs"]
        outputs = kwargs["outputs"]
        domain = kwargs.get("domain", "")
        before_lines: List[str] = []
        lines = [
            "nodes.append(",
            "    make_node(",
            f"        {op_type!r},",
            f"        {inputs},",
            f"        {outputs},",
        ]
        if domain:
            lines.append(f"        domain={domain!r},")
        for k, v in kwargs.get("atts", {}).items():
            before, value = self.render_attribute_value(v)
            before_lines.extend(before)
            lines.append(f"        {k}={value},")
        # Drop the trailing comma after the last argument.
        lines[-1] = lines[-1][:-1]
        lines.extend(["    )", ")"])
        return before_lines + lines