Skip to content

Commit dde2795

Browse files
zdevitofacebook-github-bot
authored andcommitted
Reduce number of stack manipulation instructions in interpreter. (#21240)
Summary: Pull Request resolved: #21240 ghimport-source-id: 5e9cbe8 Reviewed By: jamesr66a Differential Revision: D15590900 Pulled By: zdevito fbshipit-source-id: 98829979feba23685f0ba98ba3cb840157f7259a
1 parent c53e4d0 commit dde2795

File tree

2 files changed

+236
-48
lines changed

2 files changed

+236
-48
lines changed

test/test_jit.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,6 +2992,64 @@ def invalid_prefix_annotation3(a):
29922992
# type: (Int) -> Int
29932993
return a + 2
29942994

2995+
def test_interpreter_fuzz(self):
2996+
# This test generates random tree-like programs to fuzz test
2997+
# that the interpreter does not have a bug in its stack manipulation
2998+
# code. An assert in that code ensures individual operators are
2999+
# not reordered.
3000+
templates = [
3001+
"torch.rand(3, 4)",
3002+
"({} + {})",
3003+
"-{}",
3004+
"({} * {})",
3005+
"torch.tanh({})",
3006+
"VAR {}",
3007+
]
3008+
3009+
def gen_code():
3010+
src_lines = ['def f():']
3011+
exprs = []
3012+
n_variables = 0
3013+
3014+
def get_expr(idx):
3015+
elem = exprs[idx]
3016+
exprs[idx] = exprs[-1]
3017+
exprs.pop()
3018+
return elem
3019+
3020+
def select_expr_or_var():
3021+
idx = random.randrange(0, len(exprs) + n_variables)
3022+
if idx < len(exprs):
3023+
return get_expr(idx)
3024+
else:
3025+
return 'v{}'.format(idx - len(exprs))
3026+
3027+
for i in range(50):
3028+
n = None
3029+
while n is None or n > len(exprs) + n_variables:
3030+
template = random.choice(templates)
3031+
n = template.count('{}')
3032+
3033+
if 'VAR' in template:
3034+
src_lines.append(' v{} = {}'.format(n_variables, select_expr_or_var()))
3035+
n_variables += 1
3036+
else:
3037+
exprs.append(template.format(*(select_expr_or_var() for _ in range(n))))
3038+
3039+
src_lines.append(' return ({})\n'.format(''.join('v{},'.format(i) for i in range(n_variables))))
3040+
return '\n'.join(src_lines)
3041+
3042+
for i in range(100):
3043+
g = {'torch': torch}
3044+
code = gen_code()
3045+
torch._six.exec_(code, g, None)
3046+
cu = torch.jit.CompilationUnit(code)
3047+
with freeze_rng_state():
3048+
o1 = g['f']()
3049+
with freeze_rng_state():
3050+
o2 = cu.f()
3051+
self.assertEqual(o1, o2)
3052+
29953053
def test_tracing_multiple_methods(self):
29963054
class Net(nn.Module):
29973055
def __init__(self):

0 commit comments

Comments
 (0)