Skip to content

Commit bc19494

Browse files
voznesenskympytorchmergebot
authored andcommitted
[Dynamo] Symbolic shape guards (#87570)
**Introduces symbolic shape guards into dynamo.** In this PR, we take the existing fake tensor infra and plumbing in dynamo and we start passing a shape_env around. This shape_env does not get plumbed down to middle layers / backend yet - it only collects expressions from frontend invocations at the moment. We then translate these expressions into guards at the point where we take other guards installed throughout dynamo - and add them to check_fn. Part 1 of https://docs.google.com/document/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit# cc @jansel @lezcano @fdrocha @mlazos @soumith @yanboliang @penguinwu @anijain2305 Pull Request resolved: #87570 Approved by: https://github.com/ezyang
1 parent d0e12d1 commit bc19494

File tree

15 files changed

+427
-23
lines changed

15 files changed

+427
-23
lines changed

test/dynamo/test_dynamic_shapes.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,26 @@
33
from torch._dynamo.testing import make_test_cls_with_patches
44

55
try:
6-
from . import test_functions, test_misc, test_modules, test_repros, test_unspec
6+
from . import (
7+
test_export,
8+
test_functions,
9+
test_misc,
10+
test_modules,
11+
test_repros,
12+
test_subgraphs,
13+
test_unspec,
14+
)
715
except ImportError:
16+
import test_export
817
import test_functions
918
import test_misc
1019
import test_modules
1120
import test_repros
21+
import test_subgraphs
1222
import test_unspec
1323

24+
import unittest
25+
1426

1527
def make_dynamic_cls(cls):
1628
return make_test_cls_with_patches(
@@ -23,6 +35,145 @@ def make_dynamic_cls(cls):
2335
DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests)
2436
DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests)
2537
DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests)
38+
DynamicShapesExportTests = make_dynamic_cls(test_export.ExportTests)
39+
DynamicShapesSubGraphTests = make_dynamic_cls(test_subgraphs.SubGraphTests)
40+
41+
42+
# DynamicShapesFunctionTests
43+
unittest.expectedFailure(
44+
DynamicShapesFunctionTests.test_len_tensor_dynamic_shapes
45+
# TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
46+
)
47+
48+
unittest.expectedFailure(
49+
DynamicShapesFunctionTests.test_tensor_len_dynamic_shapes
50+
# TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
51+
)
52+
53+
54+
# DynamicShapesReproTests
55+
unittest.expectedFailure(
56+
DynamicShapesReproTests.test_reformer_eval_dynamic_shapes
57+
# TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
58+
)
59+
60+
unittest.expectedFailure(
61+
DynamicShapesReproTests.test_reformer_train_dynamic_shapes
62+
# TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
63+
)
64+
65+
unittest.expectedFailure(
66+
DynamicShapesReproTests.test_issue175_dynamic_shapes
67+
# TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
68+
)
69+
70+
unittest.expectedFailure(
71+
DynamicShapesReproTests.test_do_paste_mask_dynamic_shapes
72+
# aten.min.dim - couldn't find symbolic meta function/decomposition
73+
)
74+
75+
unittest.expectedFailure(
76+
DynamicShapesReproTests.test_convert_boxes_to_pooler_format_dynamic_shapes
77+
# Could not infer dtype of torch._C.SymIntNode
78+
)
79+
80+
unittest.expectedFailure(
81+
DynamicShapesReproTests.test_ellipsis_dynamic_shapes
82+
# Cannot call sizes() on tensor with symbolic sizes/strides
83+
)
84+
85+
unittest.expectedFailure(
86+
DynamicShapesReproTests.test_hf_t5_forward_dynamic_shapes
87+
# Cannot call sizes() on tensor with symbolic sizes/strides
88+
)
89+
90+
unittest.expectedFailure(
91+
DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes
92+
# Unable to cast Python instance to C++ type
93+
)
94+
95+
unittest.expectedFailure(
96+
DynamicShapesReproTests.test_boxes_len_dynamic_shapes
97+
# Unable to cast Python instance to C++ type
98+
)
99+
100+
unittest.expectedFailure(
101+
DynamicShapesReproTests.test_guard_fail_tensor_bool_dynamic_shapes
102+
# RuntimeError: aten.allclose.default - couldn't find symbolic meta function/decomposition
103+
)
104+
105+
# DynamicShapesMiscTests
106+
unittest.expectedFailure(
107+
DynamicShapesMiscTests.test_unsupported_fake_tensor_dynamic_shapes
108+
# aten.quantize_per_tensor.default - couldn't find symbolic meta function/decomposition
109+
)
110+
unittest.expectedFailure(
111+
DynamicShapesMiscTests.test_module_deepcopy_dynamic_shapes
112+
# aten.squeeze_.dim - couldn't find symbolic meta function/decompositio
113+
)
114+
115+
# DynamicShapesUnspecTests
116+
unittest.expectedFailure(
117+
DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes
118+
# float() argument must be a string or a real number, not 'torch._C.SymIntNode'
119+
)
120+
121+
122+
# DynamicShapesNNModuleTests
123+
unittest.expectedFailure(
124+
DynamicShapesNNModuleTests.test_unsupportedmethod_dynamic_shapes
125+
# aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
126+
)
127+
128+
unittest.expectedFailure(
129+
DynamicShapesNNModuleTests.test_unsupportedmodule_dynamic_shapes
130+
# aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
131+
)
132+
133+
unittest.expectedFailure(
134+
DynamicShapesNNModuleTests.test_self_mutating1_dynamic_shapes
135+
# aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
136+
)
137+
138+
unittest.expectedFailure(
139+
DynamicShapesNNModuleTests.test_call_fn_with_non_const_inputs_safe_dynamic_shapes
140+
# aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
141+
)
142+
143+
144+
# DynamicShapesExportTests
145+
unittest.expectedFailure(
146+
DynamicShapesExportTests.test_export_compare_optimize_with_make_fx_dynamic_shapes
147+
)
148+
unittest.expectedFailure(
149+
DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes
150+
)
151+
unittest.expectedFailure(
152+
DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes
153+
)
154+
unittest.expectedFailure(
155+
DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
156+
)
157+
unittest.expectedFailure(
158+
DynamicShapesExportTests.test_export_with_stack_trace_dynamic_shapes
159+
)
160+
unittest.expectedFailure(
161+
DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_dynamic_shapes
162+
)
163+
unittest.expectedFailure(
164+
DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass_dynamic_shapes
165+
)
166+
unittest.expectedFailure(
167+
DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dynamic_shapes
168+
)
169+
170+
171+
# DynamicShapesSubGraphTests
172+
unittest.expectedFailure(
173+
DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes
174+
)
175+
unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes)
176+
26177

27178
if __name__ == "__main__":
28179
from torch._dynamo.test_case import run_tests

test/dynamo/test_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import itertools
77
import operator
88
from typing import Any
9+
from unittest.mock import patch
910

1011
import torch
1112

test/dynamo/test_repros.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,8 +872,9 @@ def test_longformer_chunk(self):
872872
self.assertTrue(same(opt_fn(input1), correct1))
873873
self.assertTrue(same(opt_fn(input2), correct2))
874874

875-
self.assertEqual(cnt.frame_count, ifdyn(1, 2))
876-
self.assertEqual(cnt.op_count, ifdyn(19, 4))
875+
# Dyn recompiles are due to changes in hidden_state (Should we be guarding on this?)
876+
self.assertEqual(cnt.frame_count, ifdyn(4, 2))
877+
self.assertEqual(cnt.op_count, ifdyn(76, 4))
877878

878879
def test_hf_t5_forward(self):
879880
input = torch.randn([1, 2048, 512])

test/functorch/test_aotdispatch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
11741174
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
11751175
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
11761176
xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
1177+
xfail('unfold', ''), # aten.squeeze_copy.dim - couldn't find symbolic meta function/decomposition
11771178
xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
11781179
xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
11791180
xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides

test/test_proxy_tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,7 @@ def f(a, b, c, d, e):
12881288
xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
12891289
xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
12901290
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
1291+
xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
12911292
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
12921293
xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
12931294
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1305,6 +1306,7 @@ def f(a, b, c, d, e):
13051306
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition
13061307
xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition
13071308
xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
1309+
xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
13081310
xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13091311
xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
13101312
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
@@ -1354,6 +1356,8 @@ def f(a, b, c, d, e):
13541356
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13551357
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13561358
xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition
1359+
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
1360+
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
13571361
}
13581362
symbolic_tensor_segfaults = {
13591363
skip('nn.functional.batch_norm') # Segfault??
@@ -1454,6 +1458,7 @@ def f(a, b, c, d, e):
14541458
xfail('true_divide', ''), # aten.div_.Tensor - couldn't find symbolic meta function/decomposition
14551459
xfail('trunc', ''), # aten.trunc_.default - couldn't find symbolic meta function/decomposition
14561460
xfail('uniform', ''), # aten.uniform_.default - couldn't find symbolic meta function/decomposition
1461+
xfail('unique', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
14571462
xfail('unsqueeze', ''), # aten.unsqueeze_.default - couldn't find symbolic meta function/decomposition
14581463
xfail('xlogy', ''), # aten.xlogy_.Tensor - couldn't find symbolic meta function/decomposition
14591464
}

torch/_dynamo/convert_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def transform(instructions, code_options):
417417

418418
assert output.guards is not None
419419
CleanupManager.instance[out_code] = output.cleanups
420-
check_fn = CheckFunctionManager(output.guards, locals, globals)
420+
check_fn = CheckFunctionManager(output, output.guards, locals, globals)
421421

422422
guarded_code = GuardedCode(out_code, check_fn.check_fn)
423423
guard_str = "GUARDS:\n"

0 commit comments

Comments
 (0)