|
56 | 56 | execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ |
57 | 57 | RUN_CUDA |
58 | 58 | from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \ |
59 | | - EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \ |
| 59 | + get_call, script_template, EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \ |
60 | 60 | get_nn_module_name_from_kwargs, script_method_template, create_traced_fn |
61 | 61 |
|
62 | 62 | from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests |
@@ -15165,6 +15165,14 @@ def parameter_script(x: torch.nn.Parameter): |
15165 | 15165 | 'test_nn_max_pool1d_with_indices', |
15166 | 15166 | } |
15167 | 15167 |
|
| 15168 | +def check_alias_annotation(method_name, args, kwargs): |
| 15169 | + formals, tensors, actuals = get_script_args(args) |
| 15170 | + call = get_call(method_name, 'method', actuals, kwargs) |
| 15171 | + script = script_template.format(', '.join(formals), call) |
| 15172 | + CU = torch.jit.CompilationUnit(script) |
| 15173 | + torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name) |
| 15174 | + |
| 15175 | + |
15168 | 15176 | def check_output_types(self, func, ref_outputs, args, kwargs): |
15169 | 15177 | graph = getattr(func, 'last_graph', None) |
15170 | 15178 | types = [o.type() for o in graph.outputs()] |
@@ -15427,6 +15435,10 @@ def fn(*inputs, **kwargs): |
15427 | 15435 | fn, f_args_variable, kwargs_variable, |
15428 | 15436 | check_types=check_types) |
15429 | 15437 |
|
| 15438 | + # alias annotation testing |
| 15439 | + if is_inplace and test_name not in EXCLUDE_SCRIPT: |
| 15440 | + check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable) |
| 15441 | + |
15430 | 15442 | check(name) |
15431 | 15443 | inplace_name = name + '_' |
15432 | 15444 | # can't broadcast inplace to left hand side |
|
0 commit comments