-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnormalize.py
More file actions
250 lines (203 loc) · 8.67 KB
/
normalize.py
File metadata and controls
250 lines (203 loc) · 8.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
# Copyright (c) 2026 Den Rozhnovskiy
from __future__ import annotations
import ast
import copy
import hashlib
from ast import AST
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from .meta_markers import CFG_META_PREFIX
if TYPE_CHECKING:
from collections.abc import Sequence
@dataclass(frozen=True, slots=True)
class NormalizationConfig:
ignore_docstrings: bool = True
ignore_type_annotations: bool = True
normalize_attributes: bool = True
normalize_constants: bool = True
normalize_names: bool = True
class AstNormalizer(ast.NodeTransformer):
    """AST transformer that rewrites nodes into a canonical, comparable form.

    Which rewrites run is decided by the ``NormalizationConfig`` given at
    construction time. Nodes are modified in place, so callers that still
    need the original tree should hand in a copy.
    """

    __slots__ = ("cfg",)

    def __init__(self, cfg: NormalizationConfig):
        super().__init__()
        self.cfg = cfg

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        """Normalize a synchronous function definition."""
        return self._visit_func(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
        """Normalize an ``async def`` exactly like a regular function."""
        return self._visit_func(node)

    def _visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
        """Strip docstring and annotations (when configured), then recurse."""
        settings = self.cfg

        if settings.ignore_docstrings and node.body:
            head = node.body[0]
            # A docstring is a bare expression statement holding a str constant.
            looks_like_docstring = (
                isinstance(head, ast.Expr)
                and isinstance(head.value, ast.Constant)
                and isinstance(head.value.value, str)
            )
            if looks_like_docstring:
                node.body = node.body[1:]

        if settings.ignore_type_annotations:
            node.returns = None
            spec = node.args
            named_params = (
                *getattr(spec, "posonlyargs", []),
                *spec.args,
                *spec.kwonlyargs,
            )
            for param in named_params:
                param.annotation = None
            # *args / **kwargs are optional single parameters.
            for variadic in (spec.vararg, spec.kwarg):
                if variadic:
                    variadic.annotation = None

        return self.generic_visit(node)

    def visit_arg(self, node: ast.arg) -> ast.arg:
        """Clear a parameter annotation when annotations are ignored."""
        if not self.cfg.ignore_type_annotations:
            return node
        node.annotation = None
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Replace an identifier with ``_VAR_`` unless it is a marker name."""
        # Callee names never reach this method: visit_Call routes call
        # targets through _visit_call_target so their symbols are preserved.
        if not self.cfg.normalize_names:
            return node
        if _is_semantic_marker_name(node.id):
            return node
        node.id = "_VAR_"
        return node

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        """Normalize the object expression, then mask the attribute name."""
        visited = self.generic_visit(node)
        assert isinstance(visited, ast.Attribute)
        if self.cfg.normalize_attributes:
            visited.attr = "_ATTR_"
        return visited

    def visit_Constant(self, node: ast.Constant) -> ast.Constant:
        """Replace a literal value with the ``_CONST_`` placeholder."""
        if self.cfg.normalize_constants:
            node.value = "_CONST_"
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        """Normalize call arguments while keeping the call target symbolic."""
        node.func = self._visit_call_target(node.func)
        normalized_args = [cast("ast.expr", self.visit(arg)) for arg in node.args]
        node.args = normalized_args
        for keyword in node.keywords:
            # Only the value is rewritten; the keyword name is left untouched.
            keyword.value = cast("ast.expr", self.visit(keyword.value))
        return node

    def _visit_call_target(self, node: ast.expr) -> ast.expr:
        """Walk a callee expression without masking its names or attributes.

        Plain names and dotted attribute chains are kept verbatim so that
        calls to different APIs do not collapse into the same shape; any
        other expression found along the way is normalized as usual.
        """
        if isinstance(node, ast.Name):
            return node
        if not isinstance(node, ast.Attribute):
            return cast("ast.expr", self.visit(node))

        owner = node.value
        node.value = (
            self._visit_call_target(owner)
            if isinstance(owner, (ast.Name, ast.Attribute))
            else cast("ast.expr", self.visit(owner))
        )
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> AST:
        """Rewrite ``x op= v`` as ``x = x op v`` and normalize the result.

        This lets code using augmented assignment match code that spells the
        same operation out as a plain assignment.
        """
        # The right-hand side needs its own copy of the target: sharing one
        # node object between two positions of the tree is unsafe.
        rhs_target = copy.deepcopy(node.target)
        if hasattr(rhs_target, "ctx"):
            # On the right-hand side the target is read, not written.
            rhs_target.ctx = ast.Load()

        combined = ast.BinOp(left=rhs_target, op=node.op, right=node.value)
        assignment = ast.Assign(
            targets=[node.target],
            value=combined,
            lineno=node.lineno,
            col_offset=node.col_offset,
            end_lineno=getattr(node, "end_lineno", None),
            end_col_offset=getattr(node, "end_col_offset", None),
        )
        return self.generic_visit(assignment)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        """Fold ``not (a in b)`` / ``not (a is b)`` into ``not in`` / ``is not``."""
        visited = self.generic_visit(node)
        assert isinstance(visited, ast.UnaryOp)
        if not isinstance(visited.op, ast.Not):
            return visited

        inner = visited.operand
        is_simple_comparison = (
            isinstance(inner, ast.Compare)
            and len(inner.ops) == 1
            and len(inner.comparators) == 1
        )
        if not is_simple_comparison:
            return visited

        comparison_op = inner.ops[0]
        for positive, negated in ((ast.In, ast.NotIn), (ast.Is, ast.IsNot)):
            if isinstance(comparison_op, positive):
                folded = ast.Compare(
                    left=inner.left,
                    ops=[negated()],
                    comparators=inner.comparators,
                )
                return ast.copy_location(folded, visited)
        return visited

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        """Put the operands of provably commutative operations in stable order."""
        visited = self.generic_visit(node)
        assert isinstance(visited, ast.BinOp)

        operator = visited.op
        reorderable_ops = (ast.Add, ast.Mult, ast.BitOr, ast.BitAnd, ast.BitXor)
        # NOTE(review): children are visited before this check, so with
        # normalize_constants enabled every constant is already the string
        # "_CONST_" and fails the proof below; the swap then only happens
        # when normalize_constants is off. Confirm this is intended.
        can_reorder = (
            isinstance(operator, reorderable_ops)
            and _is_proven_commutative_operand(visited.left, operator)
            and _is_proven_commutative_operand(visited.right, operator)
        )
        if can_reorder:
            left_key = _expr_sort_key(visited.left)
            right_key = _expr_sort_key(visited.right)
            if right_key < left_key:
                visited.left, visited.right = visited.right, visited.left
        return visited
def _expr_sort_key(node: ast.AST) -> str:
return ast.dump(node, annotate_fields=True, include_attributes=False)
def _is_semantic_marker_name(name: str) -> bool:
    """Tell whether ``name`` begins with ``CFG_META_PREFIX``."""
    has_marker_prefix = name.startswith(CFG_META_PREFIX)
    return has_marker_prefix
def _is_proven_commutative_operand(node: ast.AST, op: ast.operator) -> bool:
if isinstance(node, ast.Constant):
return _is_proven_commutative_constant(node.value, op)
if isinstance(node, ast.BinOp) and type(node.op) is type(op):
return _is_proven_commutative_operand(
node.left, op
) and _is_proven_commutative_operand(node.right, op)
return False
def _is_proven_commutative_constant(value: object, op: ast.operator) -> bool:
if isinstance(op, (ast.BitOr, ast.BitAnd, ast.BitXor)):
return isinstance(value, int) and not isinstance(value, bool)
if isinstance(op, (ast.Add, ast.Mult)):
return isinstance(value, (int, float, complex)) and not isinstance(value, bool)
return False
def normalized_ast_dump_from_list(
    nodes: Sequence[ast.AST],
    cfg: NormalizationConfig,
    *,
    normalizer: AstNormalizer | None = None,
) -> str:
    """
    Normalize each node and return their dumps joined with ``;``.

    Every node is deep-copied before it is handed to the normalizer, which
    leaves the caller's AST untouched for later metrics and reporting
    passes. A fresh ``AstNormalizer(cfg)`` is built when ``normalizer`` is
    not supplied.
    """
    transformer = normalizer if normalizer else AstNormalizer(cfg)

    def _dump_normalized(original: ast.AST) -> str:
        # Dumps exclude location attributes, so locations are not repaired.
        rewritten = transformer.visit(copy.deepcopy(original))
        assert isinstance(rewritten, ast.AST)
        return ast.dump(rewritten, annotate_fields=True, include_attributes=False)

    return ";".join(_dump_normalized(item) for item in nodes)
def _normalized_stmt_dump(stmt: ast.stmt, normalizer: AstNormalizer) -> str:
normalized = normalizer.visit(stmt)
assert isinstance(normalized, ast.AST)
return ast.dump(normalized, annotate_fields=True, include_attributes=False)
def stmt_hashes(statements: Sequence[ast.stmt], cfg: NormalizationConfig) -> list[str]:
    """Return one SHA-1 hex digest per statement, in input order.

    A single ``AstNormalizer`` built from ``cfg`` is shared across all
    statements. Each statement's normalized dump, produced by
    ``_normalized_stmt_dump``, is UTF-8 encoded and hashed.

    Args:
        statements: Statements to fingerprint.
        cfg: Normalization switches passed to the normalizer.

    Returns:
        A list of hexadecimal SHA-1 digests; empty when ``statements`` is
        empty.
    """
    normalizer = AstNormalizer(cfg)
    return [
        hashlib.sha1(
            _normalized_stmt_dump(stmt, normalizer).encode("utf-8"),
            # The digest is a structural fingerprint, not a security
            # primitive; saying so keeps SHA-1 usable on restricted builds
            # and does not change the resulting digest.
            usedforsecurity=False,
        ).hexdigest()
        for stmt in statements
    ]