-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathutils.py
More file actions
181 lines (150 loc) · 6.57 KB
/
utils.py
File metadata and controls
181 lines (150 loc) · 6.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from enum import Enum
from pathlib import Path
from typing import Optional, Any
from libcst import MetadataDependent, matchers
from libcst.codemod import CodemodContext
from libcst.matchers import MatcherDecoratableTransformer
import libcst as cst
class BaseType(Enum):
    """
    The base literal kinds of Python that an expression can be classified as.
    """

    # Numeric values.
    NUMBER = 1
    # List values.
    LIST = 2
    # Text strings.
    STRING = 3
    # Byte strings.
    BYTES = 4
# pylint: disable-next=R0911
def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
Tries to infer if the resulting type of a given expression is one of the base literal types.
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call(
func=cst.Name("int")
) | cst.Call(func=cst.Name("float")) | cst.Call(
func=cst.Name("abs")
) | cst.Call(
func=cst.Name("len")
):
return BaseType.NUMBER
case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp():
return BaseType.LIST
case cst.Call(func=cst.Name("str")) | cst.FormattedString():
return BaseType.STRING
case cst.SimpleString():
if "b" in node.prefix.lower():
return BaseType.BYTES
return BaseType.STRING
case cst.ConcatenatedString():
return infer_expression_type(node.left)
case cst.BinaryOperation(operator=cst.Add()):
return infer_expression_type(node.left) or infer_expression_type(node.right)
case cst.BinaryOperation(operator=cst.Modulo()):
return infer_expression_type(node.left) or infer_expression_type(node.right)
case cst.IfExp():
if_true = infer_expression_type(node.body)
or_else = infer_expression_type(node.orelse)
if if_true == or_else:
return if_true
return None
class SequenceExtension:
    """
    Wrapper around a list of nodes that is meant to extend a sequence attribute.
    """

    def __init__(self, sequence: list[cst.CSTNode]) -> None:
        # The given list is stored as is (no copy is made).
        self.sequence = sequence
class Append(SequenceExtension):
    """
    Marks a list of nodes to be added after the current elements of a sequence attribute.
    """
class Prepend(SequenceExtension):
    """
    Marks a list of nodes to be added before the current elements of a sequence attribute.
    """
class ReplaceNodes(cst.CSTTransformer):
    """
    Transformer that swaps nodes according to a replacements dictionary.

    Each key is a node of the tree being transformed and its value is either:

    * a node, RemovalSentinel or FlattenSentinel, which is returned in place of the key; or
    * a dict mapping attribute names to new values, applied on top of the updated node.
      If the attribute is a sequence, the value may be Append(l)/Prepend(l), where l is a
      list of nodes, to append or prepend to the current sequence instead of replacing it.
    """

    def __init__(
        self,
        replacements: dict[
            cst.CSTNode,
            cst.CSTNode | cst.FlattenSentinel | cst.RemovalSentinel | dict[str, Any],
        ],
    ):
        self.replacements = replacements

    def on_leave(self, original_node, updated_node):
        # Nodes without a registered replacement pass through untouched.
        if original_node not in self.replacements:
            return updated_node

        replacement = self.replacements[original_node]
        if isinstance(replacement, dict):
            # Attribute-wise update of the already-transformed node.
            new_attributes = {
                attribute: self._resolve_attribute(updated_node, attribute, change)
                for attribute, change in replacement.items()
            }
            return updated_node.with_changes(**new_attributes)
        if isinstance(
            replacement, (cst.CSTNode, cst.RemovalSentinel, cst.FlattenSentinel)
        ):
            return replacement
        # Any other kind of value leaves the node as it is.
        return updated_node

    @staticmethod
    def _resolve_attribute(updated_node, attribute, change):
        """
        Compute the final value of one attribute from its requested change.
        """
        if isinstance(change, Prepend):
            return change.sequence + [*getattr(updated_node, attribute)]
        if isinstance(change, Append):
            return [*getattr(updated_node, attribute)] + change.sequence
        # Plain values overwrite the attribute.
        return change
class MetadataPreservingTransformer(
    MatcherDecoratableTransformer, cst.MetadataDependent
):
    """
    Transformer counterpart of ContextAwareVisitor: the metadata it depends on is taken from the wrapper of the context it is given. Do not chain more than one of these, otherwise the metadata will no longer reflect the state of the tree.
    """

    def __init__(self, context: CodemodContext) -> None:
        """
        Initialise both bases and copy the declared metadata dependencies from the context's wrapper.

        Raises ValueError if there are dependencies but the context has no wrapper, or if a
        dependency is absent from the wrapper's metadata.
        """
        MetadataDependent.__init__(self)
        MatcherDecoratableTransformer.__init__(self)
        self.context = context

        dependencies = self.get_inherited_dependencies()
        # Nothing to look up when no metadata is required.
        if not dependencies:
            return

        wrapper = self.context.wrapper
        if wrapper is None:
            raise ValueError(
                f"Attempting to instantiate {self.__class__.__name__} outside of "
                + "an active transform. This means that metadata hasn't been "
                + "calculated and we cannot successfully create this visitor."
            )

        # Every dependency must be present in the wrapper's metadata; the first
        # missing one is reported.
        for dep in dependencies:
            if dep in wrapper._metadata:
                continue
            raise ValueError(
                f"Attempting to access metadata {dep.__name__} that was not a "
                + "declared dependency of parent transform! This means it is "
                + "not possible to compute this value. Please ensure that all "
                + f"parent transforms of {self.__class__.__name__} declare "
                + f"{dep.__name__} as a metadata dependency."
            )

        self.metadata = {dep: wrapper._metadata[dep] for dep in dependencies}
def is_django_settings_file(file_path: Path):
    """
    Tell whether the given path looks like a Django settings file.

    The file name must contain "settings.py", and the directory that contains the
    file's own directory must exist and hold an entry named "manage.py".
    """
    if "settings.py" in file_path.name:
        # the most telling fact is the presence of a manage.py file one level above the file's directory
        project_root = file_path.parent.parent
        if project_root.is_dir():
            return any(entry.name == "manage.py" for entry in project_root.iterdir())
    return False
def is_setup_py_file(file_path: Path):
    """
    Tell whether the final component of the given path is exactly "setup.py".
    """
    file_name = file_path.name
    return file_name == "setup.py"
def get_call_name(call: cst.Call) -> str:
    """
    Extracts the name a function call is made through.

    For a composite callee such as a.b.c only the last attribute name (c) is
    returned; otherwise the `value` of the callee is returned.
    """
    callee = call.func
    # is it a composite name? e.g. a.b.c
    is_composite = matchers.matches(callee, matchers.Attribute())
    # NOTE(review): the fallback assumes the callee is a simple Name - confirm
    # that callers never pass other kinds of callee.
    return callee.attr.value if is_composite else callee.value
def get_function_name_node(call: cst.Call) -> Optional[cst.Name]:
    """
    Return the node holding the name of the called function.

    That is the callee itself when it is a Name, its `attr` when it is an
    Attribute, and None for any other kind of callee.
    """
    callee = call.func
    if isinstance(callee, cst.Name):
        return callee
    if isinstance(callee, cst.Attribute):
        return callee.attr
    return None
def is_assigned_to_True(original_node: cst.Assign):
    """
    Tell whether the assigned value of the given assignment is the name True.
    """
    assigned = original_node.value
    if not isinstance(assigned, cst.Name):
        return False
    return assigned.value == "True"