-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathast_pass.py
More file actions
232 lines (200 loc) · 8.24 KB
/
ast_pass.py
File metadata and controls
232 lines (200 loc) · 8.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from semmle.python import ast
import semmle.python.master
import sys
from semmle.python.passes._pass import Pass
from semmle.util import get_analysis_major_version
__all__ = [ 'ASTPass' ]
class ASTPass(Pass):
    '''Extract relations from AST.
    Use AST.Node objects to guide _walking of AST'''

    name = "ast"

    def __init__(self):
        # (AST class name, field name) -> offset of that field in its relation;
        # used as the child index when a field's value is walked.
        self.offsets = get_offset_table()
        # Only non-None while extract() is running.
        self.writer = None

    #Entry point

    def extract(self, root, writer):
        '''Write the relations for the AST rooted at ``root`` to ``writer``.

        Besides walking the tree, this emits the variables ``__name__``,
        ``__package__`` and ``$`` scoped to ``root`` and a
        ``py_extracted_version`` tuple. Emits nothing if ``root`` is None.
        ``self.writer`` is reset to None on exit, even when an error is raised.
        '''
        try:
            self.writer = writer
            if root is None:
                return
            self._emit_variable(ast.Variable("__name__", root))
            self._emit_variable(ast.Variable("__package__", root))
            # Introduce special variable "$" for use by the points-to library.
            self._emit_variable(ast.Variable("$", root))
            writer.write_tuple(u'py_extracted_version', 'gs', root.trap_name, get_analysis_major_version())
            self._walk(root, None, 0, root, None)
        finally:
            self.writer = None

    #Tree _walkers

    def _get_walker(self, node):
        '''Return the walker for ``node``: list, AST node, or primitive value.'''
        if isinstance(node, list):
            return self._walk_list
        elif isinstance(node, ast.AstBase):
            return self._walk_node
        else:
            return self._emit_primitive

    def _walk(self, node, parent, index, scope, description):
        '''Dispatch ``node`` to the walker appropriate for its kind.'''
        self._get_walker(node)(node, parent, index, scope, description)

    def _walk_node(self, node, parent, index, scope, _unused):
        '''Emit an AST node, then walk each of its fields.

        ``Class`` and ``Function`` nodes become the scope of their children.
        A ConsistencyError raised while walking a field has
        `` in <NodeType>[(rewritten)].<field>`` appended to its ``message``
        before being re-raised, so that as it propagates through enclosing
        nodes the message accumulates the path to the offending field.
        '''
        self._emit_node(node, parent, index, scope)
        if type(node) is ast.Name:
            assert (hasattr(node, 'variable') and
                    type(node.variable) is ast.Variable), (node, parent, index, scope)
        if type(node) in (ast.Class, ast.Function):
            scope = node
        # For scopes with a `from ... import *` statement introduce special variable "*" for use by the points-to library.
        if isinstance(node, ast.ImportFrom):
            self._emit_variable(ast.Variable("*", scope))
        node_type_name = type(node).__name__
        for field_name, desc, child_node in iter_fields(node):
            try:
                child_index = self.offsets[(node_type_name, field_name)]
                self._walk(child_node, node, child_index, scope, desc)
            except ConsistencyError as ex:
                location = ' in ' + node_type_name
                if getattr(node, 'rewritten', False):
                    location += '(rewritten)'
                location += '.' + field_name
                # Python 3 exceptions have no built-in ``message`` attribute;
                # fall back to the constructor argument so that augmenting the
                # message cannot itself fail and hide the original error.
                message = getattr(ex, 'message', None)
                if message is None:
                    message = str(ex.args[0]) if ex.args else ''
                ex.message = message + location
                raise

    def _walk_list(self, node, parent, index, scope, description):
        '''Emit a non-empty list and walk its items; empty lists emit nothing.

        Each item is walked with the list as its parent, its position in the
        list as index, and ``description.item_type`` as its description.
        '''
        assert description.is_list(), description
        if not node:
            return
        self._emit_list(node, parent, index, description)
        for i, child in enumerate(node):
            self._get_walker(child)(child, node, i, scope, description.item_type)

    #Emitters

    def _emit_node(self, ast_node, parent, index, scope):
        '''Write the relation tuple for ``ast_node``.

        Tuple layout: node id, then the sub-type index when the node is a
        sub-type, then the parent when there is one, and then ``index`` when
        that parent is shared. ``expr`` and ``stmt`` nodes additionally get a
        ``py_scopes`` tuple linking them to ``scope``.
        '''
        t = type(ast_node)
        node = _ast_nodes[t.__name__]
        #Ensure all stmts have a list as a parent.
        if isinstance(ast_node, ast.stmt):
            assert isinstance(parent, list), (ast_node, parent)
        if node.is_sub_type():
            rel_name = node.super_type.relation_name()
            shared_parent = not node.super_type.unique_parent
        else:
            rel_name = node.relation_name()
            shared_parent = node.parents is None or not node.unique_parent
        if not rel_name.endswith('s'):
            rel_name += 's'
        if t.__mro__[1] in (ast.cmpop, ast.operator, ast.expr_context, ast.unaryop, ast.boolop):
            #These nodes may be used more than once, but must have a
            #unique id for each occurrence in the AST
            fields = [ self.writer.get_unique_id() ]
            fmt = 'r'
        else:
            fields = [ ast_node ]
            fmt = 'n'
        if node.is_sub_type():
            fields.append(node.index)
            fmt += 'd'
        # Only the root is walked with parent None. Test identity rather than
        # truthiness so that a parent is never dropped for being falsy.
        if parent is not None:
            fields.append(parent)
            fmt += 'n'
            if shared_parent:
                fields.append(index)
                fmt += 'd'
        self.writer.write_tuple(rel_name, fmt, *fields)
        if t.__mro__[1] in (ast.expr, ast.stmt):
            self.writer.write_tuple(u'py_scopes', 'nn', ast_node, scope)

    def _emit_variable(self, ast_node):
        '''Write a ``variable`` tuple (variable, its scope, its id).'''
        self.writer.write_tuple(u'variable', 'nns', ast_node, ast_node.scope, ast_node.id)

    def _emit_name(self, ast_node, parent):
        '''Write the variable itself plus a ``py_variables`` link to ``parent``.'''
        self._emit_variable(ast_node)
        self.writer.write_tuple(u'py_variables', 'nn', ast_node, parent)

    def _emit_primitive(self, val, parent, index, scope, description):
        '''Emit a field value that is neither a list nor an AST node.

        None and False emit nothing. An ``ast.Variable`` is emitted as a name.
        True emits only the parent (plus index when the parent is not unique).
        Any other value is emitted with the format code chosen by
        ``format_for_primitive``.
        '''
        if val is None or val is False:
            return
        if isinstance(val, ast.Variable):
            self._emit_name(val, parent)
            return
        assert not isinstance(val, ast.AstBase)
        rel = description.relation_name()
        if val is True:
            if description.unique_parent:
                self.writer.write_tuple(rel, 'n', parent)
            else:
                self.writer.write_tuple(rel, 'nd', parent, index)
        else:
            f = format_for_primitive(val, description)
            if description.unique_parent:
                self.writer.write_tuple(rel, f + 'n', val, parent)
            else:
                self.writer.write_tuple(rel, f + 'nd', val, parent, index)

    def _emit_list(self, node, parent, index, description):
        '''Write the tuple for a list node; the index is omitted for a unique parent.'''
        rel_name = description.relation_name()
        if description.unique_parent:
            self.writer.write_tuple(rel_name, 'nn', node, parent)
        else:
            self.writer.write_tuple(rel_name, 'nnd', node, parent, index)
# AST class name -> node description, as returned by semmle.python.master.
_ast_nodes = semmle.python.master.all_nodes()
if get_analysis_major_version() < 3:
    # For analysis major version < 3, the names TryExcept and TryFinally both
    # resolve to the description of the single Try node.
    _ast_nodes['TryExcept'] = _ast_nodes['TryFinally'] = _ast_nodes['Try']
class ConsistencyError(Exception):
    '''Raised when an AST field value does not match its expected description.

    ``message`` is a plain mutable attribute: the AST walker appends location
    text to it before re-raising, and ``str()`` reflects the updated text.
    '''

    def __init__(self, message='', *args):
        super(ConsistencyError, self).__init__(message, *args)
        # Python 3 exceptions have no ``message`` attribute, so store it
        # explicitly; without this, __str__ would raise AttributeError.
        self.message = message

    def __str__(self):
        return self.message
def iter_fields(node):
    '''Yield ``(field_name, description, value)`` for the fields of ``node``.

    The fields are taken, in the order listed, from the ``fields`` of the
    node's description in ``_ast_nodes`` (looked up by class name). Fields
    that are not present as attributes on ``node`` are skipped.
    '''
    node_description = _ast_nodes[type(node).__name__]
    for field_name, field_description, _, _, _ in node_description.fields:
        if not hasattr(node, field_name):
            continue
        yield field_name, field_description, getattr(node, field_name)
# Python types that are accepted for a field whose description is named 'number'.
NUMBER_TYPES = (int, float)

def check_matches(node, node_type, owner, field):
    '''Raise ConsistencyError unless Python type ``node_type`` fits description ``node``.

    Note the naming: ``node`` is the expected field description (it has
    ``__name__`` and ``is_list()``), while ``node_type`` is the actual Python
    type found. A match is any of the following:
    * ``node_type`` is ``list`` and ``node.is_list()`` is true;
    * some class in ``node_type``'s MRO has the same ``__name__`` as ``node``;
    * ``node_type`` is in NUMBER_TYPES and ``node`` is named ``'number'``.
    ``owner`` and ``field`` are used only in the error message.
    '''
    if node_type is list:
        matched = node.is_list()
    else:
        matched = any(klass.__name__ == node.__name__ for klass in node_type.__mro__)
    if matched:
        return
    if node_type in NUMBER_TYPES and node.__name__ == 'number':
        return
    raise ConsistencyError("Found %s expected %s for field %s of %s" %
                           (node_type.__name__, node.__name__, field, owner.__name__))
def get_offset_table():
    '''Build the mapping ``(class_name, field_name) -> offset`` (in relation).

    Every node description in ``_ast_nodes`` contributes its ``layout``
    entries. The layout of ``Try`` is also registered under the
    ``TryFinally`` and ``TryExcept`` class names.
    '''
    offsets = {}
    for description in _ast_nodes.values():
        for field_name, _, offset, _, _, _ in description.layout:
            offsets[(description.__name__, field_name)] = offset
    try_layout = _ast_nodes['Try'].layout
    for field_name, _, offset, _, _, _ in try_layout:
        for alias in ('TryFinally', 'TryExcept'):
            offsets[(alias, field_name)] = offset
    return offsets
def format_for_primitive(val, description):
    '''Return the ``write_tuple`` format code for a primitive value.

    ``str`` -> ``'u'``, ``bytes`` -> ``'b'``. Any other value gets ``'d'``
    when its field description is named ``'int'``, and ``'q'`` otherwise.
    '''
    if isinstance(val, str):
        code = 'u'
    elif isinstance(val, bytes):
        code = 'b'
    else:
        code = 'd' if description.__name__ == 'int' else 'q'
    return code
class ASTVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found. This function may return a value
    which is forwarded by the `visit` method.
    This class is meant to be subclassed, with the subclass adding visitor
    methods.
    The visitor functions for the nodes are ``'visit_'`` + class name of the node.
    """

    def _get_visit_method(self, node):
        """Return the bound ``visit_<ClassName>`` method for node, else generic_visit."""
        method = 'visit_' + node.__class__.__name__
        return getattr(self, method, self.generic_visit)

    def visit(self, node):
        """Visit a node and return whatever its visitor function returns."""
        # The class contract says the visitor's return value is forwarded;
        # it used to be discarded, so visit() always returned None.
        return self._get_visit_method(node)(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node.

        Recurses into each field of an AST node (via ``visit``) and into each
        item of a list. Any other value is ignored.
        """
        if isinstance(node, ast.AstBase):
            for _, _, child in iter_fields(node):
                self.visit(child)
        elif isinstance(node, list):
            for item in node:
                # NOTE(review): list items are dispatched directly rather than
                # through self.visit(), so a subclass overriding visit() does not
                # see them. Kept as-is to preserve existing behaviour.
                self._get_visit_method(item)(item)