Skip to content

Commit 38afbba

Browse files
author
Mike Ruberry
committed
updates per review
1 parent 7e562e0 commit 38afbba

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

test/test_jit.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
5757
RUN_CUDA
5858
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \
59-
EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \
59+
get_call, script_template, EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \
6060
get_nn_module_name_from_kwargs, script_method_template, create_traced_fn
6161

6262
from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
@@ -15165,6 +15165,14 @@ def parameter_script(x: torch.nn.Parameter):
1516515165
'test_nn_max_pool1d_with_indices',
1516615166
}
1516715167

15168+
def check_alias_annotation(method_name, args, kwargs):
15169+
formals, tensors, actuals = get_script_args(args)
15170+
call = get_call(method_name, 'method', actuals, kwargs)
15171+
script = script_template.format(', '.join(formals), call)
15172+
CU = torch.jit.CompilationUnit(script)
15173+
torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
15174+
15175+
1516815176
def check_output_types(self, func, ref_outputs, args, kwargs):
1516915177
graph = getattr(func, 'last_graph', None)
1517015178
types = [o.type() for o in graph.outputs()]
@@ -15427,6 +15435,10 @@ def fn(*inputs, **kwargs):
1542715435
fn, f_args_variable, kwargs_variable,
1542815436
check_types=check_types)
1542915437

15438+
# alias annotation testing
15439+
if is_inplace and test_name not in EXCLUDE_SCRIPT:
15440+
check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
15441+
1543015442
check(name)
1543115443
inplace_name = name + '_'
1543215444
# can't broadcast inplace to left hand side

test/test_linalg.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@ def test_outer(self, device, dtype):
2020
a = torch.randn(50, device=device, dtype=dtype)
2121
b = torch.randn(50, device=device, dtype=dtype)
2222

23-
def _fn(a, b):
24-
return torch.outer(a, b)
25-
2623
ops = (torch.ger, torch.Tensor.ger,
27-
torch.outer, torch.Tensor.outer,
28-
torch.jit.script(_fn))
24+
torch.outer, torch.Tensor.outer)
2925

3026
expected = np.outer(a.cpu().numpy(), b.cpu().numpy())
3127
for op in ops:
@@ -44,11 +40,9 @@ def test_det(self, device, dtype):
4440
torch.randn((3, 52, 52), device=device, dtype=dtype),
4541
torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
4642

47-
def _fn(t):
48-
return torch.linalg.det(t)
4943

5044
ops = (torch.det, torch.Tensor.det,
51-
torch.linalg.det, torch.jit.script(_fn))
45+
torch.linalg.det)
5246
for t in tensors:
5347
expected = np.linalg.det(t.cpu().numpy())
5448
for op in ops:
@@ -57,12 +51,9 @@ def _fn(t):
5751

5852
# NOTE: det requires a 2D+ tensor
5953
t = torch.randn(1, device=device, dtype=dtype)
60-
for op in ops:
61-
try:
62-
op(t)
63-
except Exception as e:
64-
continue
65-
self.assertTrue(False, msg="Failed to throw error on 1D tensor!")
54+
with self.assertRaises(IndexError):
55+
op(t)
56+
6657

6758
instantiate_device_type_tests(TestLinalg, globals())
6859

test/test_op_normalization.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def _test(self, device, info=info):
5858
op = info.alias_op
5959
is_inplace = info.alias_name.endswith('_')
6060

61-
# NOTE: this workaround necessary to satisfy the JIT, which
62-
# does allow Python aliasing or direct calling of torch.Tensor
63-
# methods when scripting
61+
# Checks that scripting converts aliases
62+
# NOTE: the code to test scripting must be generated since
63+
# scripting does not support splatting args or directly
64+
# calling torch.Tensor methods. The following
65+
# splats args after the first tensor by inlining them as constants.
6466
if is_inplace:
6567
fn_template = '''
6668
def _fn(t):
@@ -69,9 +71,6 @@ def _fn(t):
6971
arg_string = ', '.join((str(arg) for arg in info.args))
7072
script = fn_template.format(alias_name=info.alias_name, args=arg_string)
7173
else:
72-
# Checks that scripting converts aliases
73-
# NOTE: scripting doesn't support splatting args, so this
74-
# generates the script with the args already splatted (as constants)
7574
fn_template = '''
7675
def _fn(t):
7776
return op(t{args})

0 commit comments

Comments
 (0)