Skip to content

Commit 4810365

Browse files
kauterryfacebook-github-bot
authored andcommitted
Enabled torch.testing._internal.jit_utils.* typechecking. (#44985)
Summary: Fixes #{issue number} Pull Request resolved: #44985 Reviewed By: malfet Differential Revision: D23794444 Pulled By: kauterry fbshipit-source-id: 9893cc91780338a8223904fb574efa77fa3ab2b9
1 parent 9f67176 commit 4810365

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ ignore_errors = True
5656
[mypy-torch.testing._internal.codegen.*]
5757
ignore_errors = True
5858

59-
[mypy-torch.testing._internal.jit_utils.*]
60-
ignore_errors = True
61-
6259
[mypy-torch.testing._internal.autocast_test_lists.*]
6360
ignore_errors = True
6461

torch/_C/__init__.pyi.in

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
178178
# Defined in torch/csrc/jit/python/script_init.cpp
179179
ResolutionCallback = Callable[[str], Callable[..., Any]]
180180

181+
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
182+
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
183+
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
184+
def _jit_clear_class_registry() -> None: ...
181185
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
182186
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
183187
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
@@ -395,7 +399,9 @@ class AggregationType(Enum):
395399
AVG = 1
396400

397401
class FileCheck(object):
398-
# TODO
402+
# TODO (add more FileCheck signature)
403+
def check_source_highlighted(self, highlight: str) -> 'FileCheck': ...
404+
def run(self, test_string: str) -> None: ...
399405
...
400406

401407
# Defined in torch/csrc/jit/python/init.cpp
@@ -416,6 +422,11 @@ class PyTorchFileWriter(object):
416422
def write_end_of_file(self) -> None: ...
417423
...
418424

425+
def _jit_get_inline_everything_mode() -> _bool: ...
426+
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
427+
def _jit_pass_dce(Graph) -> None: ...
428+
def _jit_pass_lint(Graph) -> None: ...
429+
419430
# Defined in torch/csrc/jit/python/python_custome_class.cpp
420431
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
421432

torch/testing/_internal/jit_utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from functools import reduce
2525
from itertools import chain
2626
from torch._six import StringIO
27+
from typing import Any, Dict
2728

2829
import inspect
2930
import io
@@ -148,14 +149,14 @@ def extract_files(buffer):
148149
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
149150
files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
150151
# unwrap all the code files into strings
151-
code_files = filter(lambda x: x.endswith('.py'), files)
152-
code_files = map(lambda f: archive.open(f), code_files)
153-
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files)
152+
code_files_str = filter(lambda x: x.endswith('.py'), files)
153+
code_files_stream = map(lambda f: archive.open(f), code_files_str)
154+
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream)
154155

155156
# unpickled all the debug files
156-
debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
157-
debug_files = map(lambda f: archive.open(f), debug_files)
158-
debug_files = map(lambda f: pickle.load(f), debug_files)
157+
debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
158+
debug_files_stream = map(lambda f: archive.open(f), debug_files_str)
159+
debug_files = map(lambda f: pickle.load(f), debug_files_stream)
159160
return code_files, debug_files
160161

161162
# disable the hook while we parse code, otherwise we will re-enter the hook
@@ -336,11 +337,15 @@ def run_pass(self, name, trace):
336337

337338
def get_frame_vars(self, frames_up):
338339
frame = inspect.currentframe()
340+
if not frame:
341+
raise RuntimeError("failed to inspect frame")
339342
i = 0
340343
while i < frames_up + 1:
341344
frame = frame.f_back
345+
if not frame:
346+
raise RuntimeError("failed to get frame")
342347
i += 1
343-
defined_vars = {}
348+
defined_vars: Dict[str, Any] = {}
344349
defined_vars.update(frame.f_locals)
345350
defined_vars.update(frame.f_globals)
346351
return defined_vars
@@ -408,7 +413,7 @@ def checkScript(self,
408413
# outputs
409414

410415
frame = self.get_frame_vars(frames_up)
411-
the_locals = {}
416+
the_locals: Dict[str, Any] = {}
412417
execWrapper(script, glob=frame, loc=the_locals)
413418
frame.update(the_locals)
414419

0 commit comments

Comments
 (0)