Skip to content

Commit cbf2a4f

Browse files
Krovatkinfacebook-github-bot
authored andcommitted
print a warning if a type annotation prefix is invalid according to mypy (#20884)
Summary: This PR adds a check that prints a warning if a type annotation prefix isn't what mypy expects. Pull Request resolved: #20884 Differential Revision: D15511043 Pulled By: Krovatkin fbshipit-source-id: 9038e074807832931faaa5f4e69628f94f51fd72
1 parent a6bb154 commit cbf2a4f

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

test/test_jit.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3373,6 +3373,28 @@ def __init__(self):
33733373
self.assertEqual(D()(v), v + v)
33743374

33753375

3376+
def test_invalid_prefix_annotation(self):
3377+
with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3378+
with self.capture_stdout() as captured:
3379+
@torch.jit.script
3380+
def invalid_prefix_annotation1(a):
3381+
#type: (Int) -> Int # noqa
3382+
return a + 2
3383+
3384+
with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3385+
with self.capture_stdout() as captured:
3386+
@torch.jit.script
3387+
def invalid_prefix_annotation2(a):
3388+
#type : (Int) -> Int # noqa
3389+
return a + 2
3390+
3391+
with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3392+
with self.capture_stdout() as captured:
3393+
@torch.jit.script
3394+
def invalid_prefix_annotation3(a):
3395+
# type: (Int) -> Int
3396+
return a + 2
3397+
33763398
def test_tracing_multiple_methods(self):
33773399
class Net(nn.Module):
33783400
def __init__(self):

torch/jit/annotations.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import ast
33
import inspect
4+
import re
45
import torch
56
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
67
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
@@ -117,8 +118,17 @@ def get_type_line(source):
117118
lines = source.split('\n')
118119
lines = [(line_num, line) for line_num, line in enumerate(lines)]
119120
type_lines = list(filter(lambda line: type_comment in line[1], lines))
121+
lines_with_type = list(filter(lambda line: 'type' in line[1], lines))
122+
120123

121124
if len(type_lines) == 0:
125+
type_pattern = re.compile('#[\t ]*type[\t ]*:')
126+
wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
127+
if len(wrong_type_lines) > 0:
128+
raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
129+
+ " is probably invalid.\nIt must be '# type:'"
130+
+ "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa
131+
+ "\nfor examples")
122132
return None
123133
elif len(type_lines) == 1:
124134
# Only 1 type line, quit now

0 commit comments

Comments
 (0)