-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpython.lua
More file actions
172 lines (144 loc) · 6.36 KB
/
python.lua
File metadata and controls
172 lines (144 loc) · 6.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
local string_utils = require("copy-python-path.utils.string")
-- Module table: every helper below is attached to it, and it is what this file returns.
local M = {}
--- Reports whether `input` has the shape of a Python identifier: an ASCII letter or
--- underscore followed by ASCII letters, digits or underscores. This is an ASCII-only
--- check, and Python keywords are not rejected.
---@param input string
---@return boolean
M.is_valid_symbol_name = function(input)
  local start_index = input:find("^[a-zA-Z_][a-zA-Z0-9_]*$")
  return start_index ~= nil
end
--- Python hard keywords, as a set. A column-0 line such as `else: x = 1` must not be read as
--- an (annotated) assignment to a variable named `else`.
local PYTHON_KEYWORDS = {}
for _, keyword in ipairs({
  "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
  "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
  "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return",
  "try", "while", "with", "yield",
}) do
  PYTHON_KEYWORDS[keyword] = true
end

--- Applies an assignment `pattern` to `code`. The pattern must capture the assigned name and
--- the single character following the `=` (possibly empty). Returns the name, or nil when there
--- is no match, the `=` is really the start of `==`, or the name is a Python keyword.
---@param code string
---@param pattern string
---@return string|nil
local function match_assigned_name(code, pattern)
  local name, after_equals = code:match(pattern)
  if name == nil or after_equals == "=" or PYTHON_KEYWORDS[name] then
    return nil
  end
  return name
end

--- Patterns for definitions that may be indented; each captures (indent, symbol name).
local DEFINITION_PATTERNS = {
  "^(%s*)class%s+([a-zA-Z_][a-zA-Z0-9_]*)", -- class definition
  "^(%s*)def%s+([a-zA-Z_][a-zA-Z0-9_]*)", -- function definition
  "^(%s*)async%s+def%s+([a-zA-Z_][a-zA-Z0-9_]*)", -- async function definition
}

--- Finds a symbol name that can be imported from another module from a single line of code.
--- Recognized forms:
--- - column-0 assignments: `NAME = value` and annotated `NAME: T = value`
---   (comparisons such as `NAME == value` are not assignments and are ignored)
--- - `class NAME`, `def NAME` and `async def NAME` at any indentation
---@param code string a single line of Python source code
---@return string|nil symbol the symbol name, or nil if no match
---@return string|nil indent the indentation of the statement, or nil if no match
M.find_importable_symbol = function(code)
  -- Module-level variables must start at column 0, so their indent is always "".
  local module_level_var = match_assigned_name(code, "^([a-zA-Z_][a-zA-Z0-9_]*)%s*=(.?)")
    or match_assigned_name(code, "^([a-zA-Z_][a-zA-Z0-9_]*)%s*:[^=]*=(.?)")
  if module_level_var then
    return module_level_var, ""
  end
  for _, pattern in ipairs(DEFINITION_PATTERNS) do
    local indent, symbol = code:match(pattern)
    if indent and symbol then
      return symbol, indent
    end
  end
  return nil, nil
end
--- First words of Python compound-statement headers other than `class`/`def`. Such a header
--- at lower indentation encloses everything below it, even though it defines no symbol.
local BLOCK_HEADER_KEYWORDS = {
  ["if"] = true, ["elif"] = true, ["else"] = true, ["for"] = true, ["while"] = true,
  ["try"] = true, ["except"] = true, ["finally"] = true, ["with"] = true,
  ["async"] = true, ["match"] = true, ["case"] = true,
}

--- First words of definition lines whose symbol may become a parent in the chain.
local SCOPE_KEYWORDS = { ["class"] = true, ["def"] = true, ["async"] = true }

--- Reports whether the last code character of `line` is `:`, optionally followed by a
--- `#` comment. This is used to tell real block headers from look-alike text lines.
---@param line string
---@return boolean
local function ends_with_colon(line)
  return line:find(":%s*$") ~= nil or line:find(":%s*#") ~= nil
end

--- If the last line of the input lines of code contains an importable symbol, gets the
--- path for locating that symbol, e.g. `{ "Outer", "Inner", "method" }`.
---
--- Walking upwards, the nearest less-indented block header is what encloses the current
--- level. If it is a `class`/`def`, its symbol is prepended. If it is another compound
--- statement header (`if ...:`, `try:`, `else:`, ...), nothing is added, but definitions
--- above it at the same or deeper indentation are no longer considered parents.
---@param lines string[] Lines of source code. Searching happens from last line
---@return string[] symbols List of symbols in hierarchical order, or empty array if no match
M.get_importable_symbol_chain = function(lines)
  if #lines == 0 then
    return {}
  end
  -- Start searching from last line
  local symbol, indent = M.find_importable_symbol(lines[#lines])
  if not symbol or not indent then
    return {}
  end
  local symbols = { symbol }
  local current_indent_level = #indent
  for i = #lines - 1, 1, -1 do
    -- Once at column 0 the code is no longer nested, so no further parents can exist.
    if current_indent_level == 0 then
      break
    end
    local line = lines[i]
    local lead, first_word = line:match("^(%s*)([a-zA-Z_][a-zA-Z0-9_]*)")
    -- Blank lines, comments and lines starting with a non-identifier (string text, brackets)
    -- never open a block. Lines at or beyond the current level cannot enclose it.
    if lead and #lead < current_indent_level then
      local parent_symbol = nil
      if SCOPE_KEYWORDS[first_word] then
        parent_symbol = M.find_importable_symbol(line)
      end
      if parent_symbol then
        table.insert(symbols, 1, parent_symbol)
        current_indent_level = #lead
      elseif BLOCK_HEADER_KEYWORDS[first_word] and ends_with_colon(line) then
        -- e.g. `if TYPE_CHECKING:`: a sibling `def` above it at this indent is not a parent.
        current_indent_level = #lead
      end
    end
  end
  return symbols
end
--- Splits one comma-separated piece of a Python import list into the imported name and its
--- `as` alias, for example:
--- - `numpy` -> `"numpy", nil`
--- - `numpy as np` -> `"numpy", "np"`
--- - `*` -> `nil, nil`
--- The segment must already be trimmed; leading or trailing spaces prevent a match.
---@param import_str string Import string segment (e.g. `numpy`, `numpy as np`, `some.path`, `*`)
---@return string|nil original_symbol, string|nil alias_symbol
M.parse_import_symbol = function(import_str)
  local name, alias = import_str:match("^([%w_%.]+)%s+as%s+([%w_]+)$")
  if name == nil then
    -- No `as` clause: the whole segment must be a (possibly dotted) name.
    return import_str:match("^([%w_%.]+)$"), nil
  end
  return name, alias
end
--- Adds the names from a comma-separated `from`-import list to `symbols_map`. Each name is
--- keyed by its alias if present, otherwise its original name. The value is
--- `<module_path>.<original name>`.
---@param symbols_map table<string, string> map updated in place
---@param module_path string module named after `from`
---@param symbols_str string the names after `import`, without parentheses or comments
local function add_from_import_symbols(symbols_map, module_path, symbols_str)
  if not symbols_str:find("%S") then
    return
  end
  local segments = vim.tbl_map(
    string_utils.trim_string,
    string_utils.split_string(symbols_str, ",")
  )
  for _, segment in ipairs(segments) do
    local original_symbol, alias_symbol = M.parse_import_symbol(segment)
    -- Segments that are not plain names (e.g. `*`) yield nil and are skipped.
    if original_symbol then
      symbols_map[alias_symbol or original_symbol] = module_path .. "." .. original_symbol
    end
  end
end

--- Adds the names bound by a plain `import` list to `symbols_map`.
---@param symbols_map table<string, string> map updated in place
---@param symbols_str string the names after `import`, without comments
local function add_import_symbols(symbols_map, symbols_str)
  local segments = vim.tbl_map(
    string_utils.trim_string,
    string_utils.split_string(symbols_str, ",")
  )
  for _, segment in ipairs(segments) do
    local path, alias_symbol = M.parse_import_symbol(segment)
    if alias_symbol then
      symbols_map[alias_symbol] = path
    elseif path and not path:find(".", 1, true) then
      -- `import a.b` does not bind `b`, so un-aliased dotted paths are ignored (see NOTE).
      symbols_map[path] = path
    end
  end
end

--- Creates a map of imported symbols to their full dotted paths.
---
--- Types of supported import symbols:
--- 1. From-imports without alias: `from X import Y, Z`
--- 2. From-imports with alias: `from X import Y as YAlias`
--- 3. Import without alias: `import numpy, pandas`
--- 4. Import with alias: `import numpy as np, user.constants as user_constants`
--- 5. Parenthesized from-imports, on one line or spread over several:
---    `from X import (Y, Z)` / `from X import (` ... `)`
---
--- Trailing `#` comments on import lines (e.g. `# noqa`) are ignored.
---
--- NOTE: For (3), we ignore the path if it contains dot(s) (e.g. `import user.services`).
--- This is because it won't create a new symbol where the name equals the word behind the last dot.
--- For example, we still need to reference symbols via `user.services.xxx`, not `services`.
---
---@param lines string[] Lines of source code, processed top to bottom; a later import of the
---  same name overwrites an earlier one
---@return table<string, string> symbols_map A map of the symbol name to its dotted path
M.get_imported_symbols_map = function(lines)
  ---@type table<string, string>
  local symbols_map = {}
  --- For `from A import B, ...`
  local from_import_pattern = "^%s*from%s+([%w%._]+)%s+import%s+(.+)%s*$"
  --- For `import A, B, ...`
  local import_pattern = "^%s*import%s+(.+)%s*$"
  -- Module of a parenthesized from-import whose closing `)` has not been reached yet.
  local open_paren_module = nil
  for _, line in ipairs(lines) do
    -- `#` cannot occur in import syntax, so everything from it onwards is a comment.
    local code = (line:gsub("#.*$", ""))
    if open_paren_module then
      local close_index = code:find(")", 1, true)
      local names = close_index and code:sub(1, close_index - 1) or code
      add_from_import_symbols(symbols_map, open_paren_module, names)
      if close_index then
        open_paren_module = nil
      end
    else
      local module_path, symbols_str = code:match(from_import_pattern)
      if module_path and symbols_str then
        local paren_body = symbols_str:match("^%((.*)$")
        if paren_body then
          local close_index = paren_body:find(")", 1, true)
          if close_index then
            symbols_str = paren_body:sub(1, close_index - 1)
          else
            -- The list continues on the following lines until `)`.
            symbols_str = paren_body
            open_paren_module = module_path
          end
        end
        add_from_import_symbols(symbols_map, module_path, symbols_str)
      else
        local import_symbols_str = code:match(import_pattern)
        if import_symbols_str then
          add_import_symbols(symbols_map, import_symbols_str)
        end
      end
    end
  end
  return symbols_map
end
--- Generates an import statement from a symbol's dotted path. Everything before the last dot
--- is the module, and the part after it is the imported name. Examples:
--- - `"numpy"` -> `"import numpy"`
--- - `"some.module.foo"` -> `"from some.module import foo"`
--- - `"..foo"` -> `"from . import foo"`
--- - `".foo"` -> `"from . import foo"` (an empty module part becomes `.`)
---@param dotted_path string Dotted path of a symbol (e.g. `some.module.Symbol`)
---@return string
M.make_import_statement = function(dotted_path)
  -- Greedy `.*` makes the split happen at the last dot.
  local module_path, symbol_name = dotted_path:match("^(.*)%.([^%.]*)$")
  if module_path == nil then
    return "import " .. dotted_path
  end
  -- Without this, ".foo" would produce the invalid statement "from  import foo".
  if module_path == "" then
    module_path = "."
  end
  return "from " .. module_path .. " import " .. symbol_name
end
-- Export the module table holding all helpers defined above.
return M