@@ -2992,6 +2992,64 @@ def invalid_prefix_annotation3(a):
29922992 # type: (Int) -> Int
29932993 return a + 2
29942994
2995+ def test_interpreter_fuzz(self):
2996+ # This test generates random tree-like programs to fuzz test
2997+ # that the interpreter does not have a bug in its stack manipulation
2998+ # code. An assert in that code ensures individual operators are
2999+ # not reordered.
3000+ templates = [
3001+ "torch.rand(3, 4)",
3002+ "({} + {})",
3003+ "-{}",
3004+ "({} * {})",
3005+ "torch.tanh({})",
3006+ "VAR {}",
3007+ ]
3008+
3009+ def gen_code():
3010+ src_lines = ['def f():']
3011+ exprs = []
3012+ n_variables = 0
3013+
3014+ def get_expr(idx):
3015+ elem = exprs[idx]
3016+ exprs[idx] = exprs[-1]
3017+ exprs.pop()
3018+ return elem
3019+
3020+ def select_expr_or_var():
3021+ idx = random.randrange(0, len(exprs) + n_variables)
3022+ if idx < len(exprs):
3023+ return get_expr(idx)
3024+ else:
3025+ return 'v{}'.format(idx - len(exprs))
3026+
3027+ for i in range(50):
3028+ n = None
3029+ while n is None or n > len(exprs) + n_variables:
3030+ template = random.choice(templates)
3031+ n = template.count('{}')
3032+
3033+ if 'VAR' in template:
3034+ src_lines.append(' v{} = {}'.format(n_variables, select_expr_or_var()))
3035+ n_variables += 1
3036+ else:
3037+ exprs.append(template.format(*(select_expr_or_var() for _ in range(n))))
3038+
3039+ src_lines.append(' return ({})\n'.format(''.join('v{},'.format(i) for i in range(n_variables))))
3040+ return '\n'.join(src_lines)
3041+
3042+ for i in range(100):
3043+ g = {'torch': torch}
3044+ code = gen_code()
3045+ torch._six.exec_(code, g, None)
3046+ cu = torch.jit.CompilationUnit(code)
3047+ with freeze_rng_state():
3048+ o1 = g['f']()
3049+ with freeze_rng_state():
3050+ o2 = cu.f()
3051+ self.assertEqual(o1, o2)
3052+
29953053 def test_tracing_multiple_methods(self):
29963054 class Net(nn.Module):
29973055 def __init__(self):
0 commit comments