|
24 | 24 | from functools import reduce |
25 | 25 | from itertools import chain |
26 | 26 | from torch._six import StringIO |
| 27 | +from typing import Any, Dict |
27 | 28 |
|
28 | 29 | import inspect |
29 | 30 | import io |
@@ -148,14 +149,14 @@ def extract_files(buffer): |
148 | 149 | self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) |
149 | 150 | files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) |
150 | 151 | # unwrap all the code files into strings |
151 | | - code_files = filter(lambda x: x.endswith('.py'), files) |
152 | | - code_files = map(lambda f: archive.open(f), code_files) |
153 | | - code_files = map(lambda file: "".join([line.decode() for line in file]), code_files) |
| 152 | + code_files_str = filter(lambda x: x.endswith('.py'), files) |
| 153 | + code_files_stream = map(lambda f: archive.open(f), code_files_str) |
| 154 | + code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream) |
154 | 155 |
|
155 | 156 | # unpickled all the debug files |
156 | | - debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) |
157 | | - debug_files = map(lambda f: archive.open(f), debug_files) |
158 | | - debug_files = map(lambda f: pickle.load(f), debug_files) |
| 157 | + debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) |
| 158 | + debug_files_stream = map(lambda f: archive.open(f), debug_files_str) |
| 159 | + debug_files = map(lambda f: pickle.load(f), debug_files_stream) |
159 | 160 | return code_files, debug_files |
160 | 161 |
|
161 | 162 | # disable the hook while we parse code, otherwise we will re-enter the hook |
@@ -336,11 +337,15 @@ def run_pass(self, name, trace): |
336 | 337 |
|
337 | 338 | def get_frame_vars(self, frames_up): |
338 | 339 | frame = inspect.currentframe() |
| 340 | + if not frame: |
| 341 | + raise RuntimeError("failed to inspect frame") |
339 | 342 | i = 0 |
340 | 343 | while i < frames_up + 1: |
341 | 344 | frame = frame.f_back |
| 345 | + if not frame: |
| 346 | + raise RuntimeError("failed to get frame") |
342 | 347 | i += 1 |
343 | | - defined_vars = {} |
| 348 | + defined_vars: Dict[str, Any] = {} |
344 | 349 | defined_vars.update(frame.f_locals) |
345 | 350 | defined_vars.update(frame.f_globals) |
346 | 351 | return defined_vars |
@@ -408,7 +413,7 @@ def checkScript(self, |
408 | 413 | # outputs |
409 | 414 |
|
410 | 415 | frame = self.get_frame_vars(frames_up) |
411 | | - the_locals = {} |
| 416 | + the_locals: Dict[str, Any] = {} |
412 | 417 | execWrapper(script, glob=frame, loc=the_locals) |
413 | 418 | frame.update(the_locals) |
414 | 419 |
|
|
0 commit comments