Skip to content

Commit df338f8

Browse files
Dmytro Dzhulgakovfacebook-github-bot
authored andcommitted
Add a wrapper for inspect in JIT to produce better error message (#25415)
Summary: If source code is not available due to packaging (e.g. sources are compiled to .pyc), TorchScript produces very obscure error message. This tries to make it nicer and allow to customize message by overriding _utils_internal. Pull Request resolved: #25415 Test Plan: Really hard to unittest properly. Did one off testing by compiling to .pyc and checking the message. Differential Revision: D17118238 Pulled By: dzhulgakov fbshipit-source-id: 3cbfee0abddc8613000680548bfe0b8ed52a36b0
1 parent 7f3c423 commit df338f8

File tree

5 files changed

+34
-14
lines changed

5 files changed

+34
-14
lines changed

torch/_jit_internal.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
import torch._C
1111
from torch._six import builtins
12+
from torch._utils_internal import get_source_lines_and_file
1213

1314
# Wrapper functions that can call either of 2 functions depending on a boolean
1415
# argument
@@ -477,11 +478,11 @@ def _get_overloaded_methods(method, mod_class):
477478
if overloads is None:
478479
return None
479480

480-
method_line_no = inspect.getsourcelines(method)[1]
481-
mod_class_fileno = inspect.getsourcelines(mod_class)[1]
482-
mod_end_fileno = mod_class_fileno + len(inspect.getsourcelines(mod_class)[0])
481+
method_line_no = get_source_lines_and_file(method)[1]
482+
mod_class_fileno = get_source_lines_and_file(mod_class)[1]
483+
mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
483484
if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
484-
raise Exception("Overloads are not useable when a module is redaclared within the same file: " + str(method))
485+
raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method))
485486
return overloads
486487

487488
try:

torch/_utils_internal.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import absolute_import, division, print_function, unicode_literals
22

33
import os
4+
import inspect
45

56
# this arbitrary-looking assortment of functionality is provided here
67
# to have a central place for overrideable behavior. The motivating
@@ -33,5 +34,23 @@ def resolve_library_path(path):
3334
return os.path.realpath(path)
3435

3536

37+
def get_source_lines_and_file(obj):
38+
"""
39+
Wrapper around inspect.getsourcelines and inspect.getsourcefile.
40+
41+
Returns: (sourcelines, file_lino, filename)
42+
"""
43+
filename = None # in case getsourcefile throws
44+
try:
45+
filename = inspect.getsourcefile(obj)
46+
sourcelines, file_lineno = inspect.getsourcelines(obj)
47+
except OSError as e:
48+
raise OSError((
49+
"Can't get source for {}. TorchScript requires source access in order to carry out compilation. " +
50+
"Make sure original .py files are available. Original error: {}").format(filename, e))
51+
52+
return sourcelines, file_lineno, filename
53+
54+
3655
TEST_MASTER_ADDR = '127.0.0.1'
3756
TEST_MASTER_PORT = 29500

torch/jit/annotations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torch._C import TensorType, TupleType, FloatType, IntType, \
1010
ListType, StringType, DictType, BoolType, OptionalType, ClassType
1111
from textwrap import dedent
12+
from torch._utils_internal import get_source_lines_and_file
1213

1314

1415
PY35 = sys.version_info >= (3, 5)
@@ -46,7 +47,7 @@ def get_signature(fn):
4647

4748
type_line, source = None, None
4849
try:
49-
source = dedent(inspect.getsource(fn))
50+
source = dedent(''.join(get_source_lines_and_file(fn)[0]))
5051
type_line = get_type_line(source)
5152
except TypeError:
5253
pass
@@ -63,7 +64,7 @@ def get_signature(fn):
6364
# a function takes.
6465
def get_num_params(fn, loc):
6566
try:
66-
source = dedent(inspect.getsource(fn))
67+
source = dedent(''.join(get_source_lines_and_file(fn)[0]))
6768
except (TypeError, IOError):
6869
return None
6970
if source is None:

torch/jit/frontend.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from textwrap import dedent
88
from torch._six import PY2
99
from torch._C._jit_tree_views import *
10+
from torch._utils_internal import get_source_lines_and_file
1011

1112
# Borrowed from cPython implementation
1213
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
@@ -146,9 +147,8 @@ def get_jit_class_def(cls, self_name):
146147
method_defs = [get_jit_def(method[1],
147148
self_name=self_name) for method in methods]
148149

149-
sourcelines, file_lineno = inspect.getsourcelines(cls)
150+
sourcelines, file_lineno, filename = get_source_lines_and_file(cls)
150151
source = ''.join(sourcelines)
151-
filename = inspect.getsourcefile(cls)
152152
dedent_src = dedent(source)
153153
py_ast = ast.parse(dedent_src)
154154
leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
@@ -157,9 +157,8 @@ def get_jit_class_def(cls, self_name):
157157

158158

159159
def get_jit_def(fn, self_name=None):
160-
sourcelines, file_lineno = inspect.getsourcelines(fn)
160+
sourcelines, file_lineno, filename = get_source_lines_and_file(fn)
161161
source = ''.join(sourcelines)
162-
filename = inspect.getsourcefile(fn)
163162
dedent_src = dedent(source)
164163
py_ast = ast.parse(dedent_src)
165164
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):

torch/serialization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import difflib
2-
import inspect
32
import os
43
import io
54
import shutil
@@ -12,6 +11,7 @@
1211
from contextlib import closing, contextmanager
1312
from ._utils import _import_dotted_name
1413
from ._six import string_classes as _string_classes
14+
from torch._utils_internal import get_source_lines_and_file
1515
if sys.version_info[0] == 2:
1616
import cPickle as pickle
1717
else:
@@ -285,8 +285,8 @@ def persistent_id(obj):
285285
serialized_container_types[obj] = True
286286
source_file = source = None
287287
try:
288-
source_file = inspect.getsourcefile(obj)
289-
source = inspect.getsource(obj)
288+
source_lines, _, source_file = get_source_lines_and_file(obj)
289+
source = ''.join(obj)
290290
except Exception: # saving the source is optional, so we can ignore any errors
291291
warnings.warn("Couldn't retrieve source code for container of "
292292
"type " + obj.__name__ + ". It won't be checked "
@@ -449,7 +449,7 @@ def restore_location(storage, location):
449449

450450
def _check_container_source(container_type, source_file, original_source):
451451
try:
452-
current_source = inspect.getsource(container_type)
452+
current_source = ''.join(get_source_lines_and_file(container_type)[0])
453453
except Exception: # saving the source is optional, so we can ignore any errors
454454
warnings.warn("Couldn't retrieve source code for container of "
455455
"type " + container_type.__name__ + ". It won't be checked "

0 commit comments

Comments
 (0)