Skip to content

Commit a26b4df

Browse files
authored
Add common plugin API available for all hooks (python#6044)
There is a problem with `get_base_class_hook()`: only the full name of the base class is passed to it. This causes problems when this hook needs to interact with other hooks. In normal mode this can be solved by using a global plugin state, but this doesn't work in incremental mode, since this plugin state is not stored in cache. There is however already a field that is stored in cache between incremental runs and is free to use by plugins: `TypeInfo.metadata`. The problem however is that plugin hooks get only full names and can't access the corresponding `TypeInfo`. This is a common enough problem that we should add an API to make the nodes referred by full names accessible to the plugin hooks. This is done by adding a common lookup method to the `Plugin` class.
1 parent 84c2d7f commit a26b4df

File tree

8 files changed

+199
-38
lines changed

8 files changed

+199
-38
lines changed

mypy/build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,7 @@ def load_graph(sources: List[BuildSource], manager: BuildManager,
23862386
if dep in st.suppressed:
23872387
st.suppressed.remove(dep)
23882388
st.dependencies.append(dep)
2389+
manager.plugin.set_modules(manager.modules)
23892390
return graph
23902391

23912392

mypy/fixup.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
TypeType, NOT_READY
1414
)
1515
from mypy.visitor import NodeVisitor
16+
from mypy.lookup import lookup_fully_qualified
1617

1718

1819
# N.B: we do a quick_and_dirty fixup in both quick_and_dirty mode and
@@ -257,40 +258,7 @@ def lookup_qualified(modules: Dict[str, MypyFile], name: str,
257258

258259
def lookup_qualified_stnode(modules: Dict[str, MypyFile], name: str,
259260
quick_and_dirty: bool) -> Optional[SymbolTableNode]:
260-
head = name
261-
rest = []
262-
while True:
263-
if '.' not in head:
264-
if not quick_and_dirty:
265-
assert '.' in head, "Cannot find %s" % (name,)
266-
return None
267-
head, tail = head.rsplit('.', 1)
268-
rest.append(tail)
269-
mod = modules.get(head)
270-
if mod is not None:
271-
break
272-
names = mod.names
273-
while True:
274-
if not rest:
275-
if not quick_and_dirty:
276-
assert rest, "Cannot find %s" % (name,)
277-
return None
278-
key = rest.pop()
279-
if key not in names:
280-
if not quick_and_dirty:
281-
assert key in names, "Cannot find %s for %s" % (key, name)
282-
return None
283-
stnode = names[key]
284-
if not rest:
285-
return stnode
286-
node = stnode.node
287-
# In fine-grained mode, could be a cross-reference to a deleted module
288-
# or a Var made up for a missing module.
289-
if not isinstance(node, TypeInfo):
290-
if not quick_and_dirty:
291-
assert node, "Cannot find %s" % (name,)
292-
return None
293-
names = node.names
261+
return lookup_fully_qualified(name, modules, raise_on_missing=not quick_and_dirty)
294262

295263

296264
def stale_info(modules: Dict[str, MypyFile]) -> TypeInfo:

mypy/interpreted_plugin.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Hack for handling non-mypyc compiled plugins with a mypyc-compiled mypy"""
22

3-
from typing import Optional, Callable, Any
3+
from typing import Optional, Callable, Any, Dict
44
from mypy.options import Options
55
from mypy.types import Type, CallableType
6-
6+
from mypy.nodes import SymbolTableNode, MypyFile
7+
from mypy.lookup import lookup_fully_qualified
78

89
MYPY = False
910
if MYPY:
@@ -29,6 +30,14 @@ def __new__(cls, *args: Any, **kwargs: Any) -> 'mypy.plugin.Plugin':
2930
def __init__(self, options: Options) -> None:
3031
self.options = options
3132
self.python_version = options.python_version
33+
self._modules = None # type: Optional[Dict[str, MypyFile]]
34+
35+
def set_modules(self, modules: Dict[str, MypyFile]) -> None:
36+
self._modules = modules
37+
38+
def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
39+
assert self._modules is not None
40+
return lookup_fully_qualified(fullname, self._modules)
3241

3342
def get_type_analyze_hook(self, fullname: str
3443
) -> Optional[Callable[['mypy.plugin.AnalyzeTypeContext'], Type]]:

mypy/lookup.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
This is a module for various lookup functions:
3+
functions that will find a semantic node by its name.
4+
"""
5+
6+
from mypy.nodes import MypyFile, SymbolTableNode, TypeInfo
7+
from typing import Dict, Optional
8+
9+
# TODO: gradually move existing lookup functions to this module.
10+
11+
12+
def lookup_fully_qualified(name: str, modules: Dict[str, MypyFile],
13+
raise_on_missing: bool = False) -> Optional[SymbolTableNode]:
14+
"""Find a symbol using it fully qualified name.
15+
16+
The algorithm has two steps: first we try splitting the name on '.' to find
17+
the module, then iteratively look for each next chunk after a '.' (e.g. for
18+
nested classes).
19+
20+
This function should *not* be used to find a module. Those should be looked
21+
in the modules dictionary.
22+
"""
23+
head = name
24+
rest = []
25+
# 1. Find a module tree in modules dictionary.
26+
while True:
27+
if '.' not in head:
28+
if raise_on_missing:
29+
assert '.' in head, "Cannot find module for %s" % (name,)
30+
return None
31+
head, tail = head.rsplit('.', maxsplit=1)
32+
rest.append(tail)
33+
mod = modules.get(head)
34+
if mod is not None:
35+
break
36+
names = mod.names
37+
# 2. Find the symbol in the module tree.
38+
if not rest:
39+
# Looks like a module, don't use this to avoid confusions.
40+
if raise_on_missing:
41+
assert rest, "Cannot find %s, got a module symbol" % (name,)
42+
return None
43+
while True:
44+
key = rest.pop()
45+
if key not in names:
46+
if raise_on_missing:
47+
assert key in names, "Cannot find %s for %s" % (key, name)
48+
return None
49+
stnode = names[key]
50+
if not rest:
51+
return stnode
52+
node = stnode.node
53+
# In fine-grained mode, could be a cross-reference to a deleted module
54+
# or a Var made up for a missing module.
55+
if not isinstance(node, TypeInfo):
56+
if raise_on_missing:
57+
assert node, "Cannot find %s" % (name,)
58+
return None
59+
names = node.names

mypy/plugin.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,14 @@
4343
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar, Dict
4444
from mypy_extensions import trait
4545

46-
from mypy.nodes import Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr
46+
from mypy.nodes import (
47+
Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr
48+
)
4749
from mypy.tvar_scope import TypeVarScope
4850
from mypy.types import Type, Instance, CallableType, TypeList, UnboundType
4951
from mypy.messages import MessageBuilder
5052
from mypy.options import Options
53+
from mypy.lookup import lookup_fully_qualified
5154
import mypy.interpreted_plugin
5255

5356

@@ -96,6 +99,28 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type],
9699
('api', TypeAnalyzerPluginInterface)])
97100

98101

102+
@trait
103+
class CommonPluginApi:
104+
"""
105+
A common plugin API (shared between semantic analysis and type checking phases)
106+
that all plugin hooks get independently of the context.
107+
"""
108+
109+
# Global mypy options.
110+
# Per-file options can be only accessed on various
111+
# XxxPluginInterface classes.
112+
options = None # type: Options
113+
114+
@abstractmethod
115+
def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
116+
"""Lookup a symbol by its full name (including module).
117+
118+
This lookup function available for all plugins. Return None if a name
119+
is not found. This function doesn't support lookup from current scope.
120+
Use SemanticAnalyzerPluginInterface.lookup_qualified() for this."""
121+
raise NotImplementedError
122+
123+
99124
@trait
100125
class CheckerPluginInterface:
101126
"""Interface for accessing type checker functionality in plugins.
@@ -286,7 +311,7 @@ def qualified_name(self, n: str) -> str:
286311
])
287312

288313

289-
class Plugin:
314+
class Plugin(CommonPluginApi):
290315
"""Base class of all type checker plugins.
291316
292317
This defines a no-op plugin. Subclasses can override some methods to
@@ -303,6 +328,17 @@ class Plugin:
303328
def __init__(self, options: Options) -> None:
304329
self.options = options
305330
self.python_version = options.python_version
331+
# This can't be set in __init__ because it is executed too soon in build.py.
332+
# Therefore, build.py *must* set it later before graph processing starts
333+
# by calling set_modules().
334+
self._modules = None # type: Optional[Dict[str, MypyFile]]
335+
336+
def set_modules(self, modules: Dict[str, MypyFile]) -> None:
337+
self._modules = modules
338+
339+
def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
340+
assert self._modules is not None
341+
return lookup_fully_qualified(fullname, self._modules)
306342

307343
def get_type_analyze_hook(self, fullname: str
308344
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
@@ -479,6 +515,12 @@ def __init__(self, plugin: mypy.interpreted_plugin.InterpretedPlugin) -> None:
479515
super().__init__(plugin.options)
480516
self.plugin = plugin
481517

518+
def set_modules(self, modules: Dict[str, MypyFile]) -> None:
519+
self.plugin.set_modules(modules)
520+
521+
def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
522+
return self.plugin.lookup_fully_qualified(fullname)
523+
482524
def get_type_analyze_hook(self, fullname: str
483525
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
484526
return self.plugin.get_type_analyze_hook(fullname)
@@ -540,6 +582,10 @@ def __init__(self, options: Options, plugins: List[Plugin]) -> None:
540582
super().__init__(options)
541583
self._plugins = plugins
542584

585+
def set_modules(self, modules: Dict[str, MypyFile]) -> None:
586+
for plugin in self._plugins:
587+
plugin.set_modules(modules)
588+
543589
def get_type_analyze_hook(self, fullname: str
544590
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
545591
return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname))

mypy/semanal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3504,6 +3504,9 @@ def lookup_fully_qualified_or_none(self, fullname: str) -> Optional[SymbolTableN
35043504
35053505
Note that this can't be used for names nested in class namespaces.
35063506
"""
3507+
# TODO: unify/clean-up/simplify lookup methods, see #4157.
3508+
# TODO: support nested classes (but consider performance impact,
3509+
# we might keep the module level only lookup for thing like 'builtins.int').
35073510
assert '.' in fullname
35083511
module, name = fullname.rsplit('.', maxsplit=1)
35093512
if module not in self.modules:

test-data/unit/check-custom-plugin.test

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,34 @@ class Instr(Generic[T]): ...
331331
[file mypy.ini]
332332
[[mypy]
333333
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
334+
335+
[case testBaseClassPluginHookWorksIncremental]
336+
# flags: --config-file tmp/mypy.ini
337+
import a
338+
339+
[file a.py]
340+
from base import Base
341+
class C(Base): ...
342+
343+
[file a.py.2]
344+
from base import Base
345+
class C(Base): ...
346+
reveal_type(C().__magic__)
347+
Base.__magic__
348+
349+
[file base.py]
350+
from lib import declarative_base
351+
Base = declarative_base()
352+
353+
[file lib.py]
354+
from typing import Any
355+
def declarative_base() -> Any: ...
356+
357+
[file mypy.ini]
358+
[[mypy]
359+
python_version=3.6
360+
plugins=<ROOT>/test-data/unit/plugins/common_api_incremental.py
361+
[out]
362+
[out2]
363+
tmp/a.py:3: error: Revealed type is 'builtins.str'
364+
tmp/a.py:4: error: "Type[Base]" has no attribute "__magic__"
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from mypy.plugin import Plugin
2+
from mypy.nodes import (
3+
ClassDef, Block, TypeInfo, SymbolTable, SymbolTableNode, MDEF, GDEF, Var
4+
)
5+
6+
7+
class DynPlugin(Plugin):
8+
def get_dynamic_class_hook(self, fullname):
9+
if fullname == 'lib.declarative_base':
10+
return add_info_hook
11+
return None
12+
13+
def get_base_class_hook(self, fullname: str):
14+
sym = self.lookup_fully_qualified(fullname)
15+
if sym and isinstance(sym.node, TypeInfo):
16+
if sym.node.metadata.get('magic'):
17+
return add_magic_hook
18+
return None
19+
20+
21+
def add_info_hook(ctx) -> None:
22+
class_def = ClassDef(ctx.name, Block([]))
23+
class_def.fullname = ctx.api.qualified_name(ctx.name)
24+
25+
info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
26+
class_def.info = info
27+
obj = ctx.api.builtin_type('builtins.object')
28+
info.mro = [info, obj.type]
29+
info.bases = [obj]
30+
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
31+
info.metadata['magic'] = True
32+
33+
34+
def add_magic_hook(ctx) -> None:
35+
info = ctx.cls.info
36+
str_type = ctx.api.named_type_or_none('builtins.str', [])
37+
assert str_type is not None
38+
var = Var('__magic__', str_type)
39+
var.info = info
40+
info.names['__magic__'] = SymbolTableNode(MDEF, var)
41+
42+
43+
def plugin(version):
44+
return DynPlugin

0 commit comments

Comments
 (0)