-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathbase_visitor.py
More file actions
133 lines (113 loc) · 4.86 KB
/
base_visitor.py
File metadata and controls
133 lines (113 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from functools import cache
from typing import ClassVar, Collection, cast
import libcst as cst
from libcst import MetadataDependent
from libcst._position import CodePosition, CodeRange
from libcst.codemod import ContextAwareVisitor, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider, ProviderT
from codemodder.result import Result
# TODO: this should just be part of BaseTransformer and BaseVisitor?
class UtilsMixin(MetadataDependent):
METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (PositionProvider,)
def __init__(
self,
results: list[Result] | None,
line_exclude: list[int],
line_include: list[int],
):
self.results = results
self.line_exclude = line_exclude
self.line_include = line_include
def filter_by_result(self, node: cst.CSTNode) -> bool:
# Codemods with detectors will only run their transformations if there are results.
return self.results is None or any(self.results_for_node(node))
@cache
def results_for_node(self, node: cst.CSTNode) -> list[Result]:
pos_to_match = self.node_position(node)
return (
[
result
for result in self.results
if result.match_location(pos_to_match, node)
]
if self.results
else []
)
def filter_by_path_includes_or_excludes(self, pos_to_match):
"""
Returns True if the node, whose position in the file is pos_to_match, matches any of the lines specified in the path-includes or path-excludes flags.
"""
# excludes takes precedence if defined
if self.line_exclude:
return not any(match_line(pos_to_match, line) for line in self.line_exclude)
if self.line_include:
return any(match_line(pos_to_match, line) for line in self.line_include)
return True
def node_is_selected(self, node) -> bool:
pos_to_match = self.node_position(node)
return self.filter_by_result(node) and self.filter_by_path_includes_or_excludes(
pos_to_match
)
def node_is_selected_by_line_only(self, node) -> bool:
pos_to_match = self.node_position(node)
return self.filter_by_result_line_only(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match)
def filter_by_result_line_only(self, pos_to_match) -> bool:
# Codemods with detectors will only run their transformations if there are results.
return self.results is None or any(
pos_to_match.start.line >= location.start.line
and pos_to_match.end.line <= location.end.line
for r in self.results
for location in r.locations
)
def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
match node:
case cst.FunctionDef():
# By default a function's position includes the entire
# function definition. Instead, we will only use the first line
# of the function definition.
params_end = cast(
CodeRange, self.get_metadata(PositionProvider, node.params)
).end
return CodeRange(
start=cast(
CodeRange, self.get_metadata(PositionProvider, node)
).start,
end=CodePosition(params_end.line, params_end.column + 1),
)
case _:
return cast(CodeRange, self.get_metadata(PositionProvider, node))
def lineno_for_node(self, node):
return self.node_position(node).start.line
def code(self, node: cst.CSTNode) -> str:
"""
Only a cst.Module node has a `code` attribute which converts the node
back to the original code as a str. To get the code for any node,
the suggested approach is to wrap this node in a `cst.Module` node.
"""
module = cst.Module(body=[cst.SimpleStatementLine(body=[cst.Expr(value=node)])])
return module.code
class BaseTransformer(VisitorBasedCodemodCommand, UtilsMixin):
    """Codemod command that also carries the ``UtilsMixin`` filtering state."""

    def __init__(
        self,
        context,
        results: list[Result] | None,
        line_include: list[int],
        line_exclude: list[int],
    ):
        """Initialise both bases explicitly.

        Note that this signature takes ``line_include`` before
        ``line_exclude`` while ``UtilsMixin.__init__`` takes them the other
        way round; keywords are used below so the two cannot be mixed up.
        """
        VisitorBasedCodemodCommand.__init__(self, context)
        UtilsMixin.__init__(
            self,
            results=results,
            line_exclude=line_exclude,
            line_include=line_include,
        )
class BaseVisitor(ContextAwareVisitor, UtilsMixin):
    """Context-aware visitor that also carries the ``UtilsMixin`` filtering state."""

    def __init__(
        self,
        context,
        results: list[Result] | None,
        line_include: list[int],
        line_exclude: list[int],
    ):
        """Initialise both bases explicitly.

        Note that this signature takes ``line_include`` before
        ``line_exclude`` while ``UtilsMixin.__init__`` takes them the other
        way round; keywords are used below so the two cannot be mixed up.
        """
        ContextAwareVisitor.__init__(self, context)
        UtilsMixin.__init__(
            self,
            results=results,
            line_exclude=line_exclude,
            line_include=line_include,
        )
def match_line(pos, line):
    """Return True when ``pos`` both starts and ends on ``line``.

    ``pos`` is any object exposing ``start.line`` and ``end.line``; a
    position spanning several lines therefore never matches.
    """
    if pos.start.line != line:
        return False
    return pos.end.line == line