Commit d36b83b

Stop immediately specializing common constants 0/1 for plain int
Fixes #128319
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: c431efc
Pull Request resolved: #128327
1 parent: 25fcb1c

7 files changed (+42, -16 lines)

docs/source/torch.compiler_dynamo_deepdive.rst

Lines changed: 8 additions & 5 deletions
@@ -598,16 +598,19 @@ This is mostly useful for debugging purposes.
 0, 1 are always specialized
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Regardless of whether we mark a dimension as dynamic, or we have traced
-an integer as dynamic, if we pass an input where that dimension is 0 or
-1, Dynamo will trace it as non-dynamic and it will generate a specific
-graph for it. This is the reason why in the example above we find guards
-of the form ``2 <= L['a'].size()[0]``.
+Regardless of whether we mark a dimension as dynamic, if we pass an input
+where that dimension is 0 or 1, Dynamo will trace it as non-dynamic and it
+will generate a specific graph for it. This is the reason why in the example
+above we find guards of the form ``2 <= L['a'].size()[0]``.
 
 There are several reasons for this choice. There are two particularly
 important - A tensor is empty if and only if any of its dimensions is
 zero - A tensor can only be contiguous if one of the strides is one
 
+This policy decision does NOT apply to plain Python ints; if we think a Python
+int should be compiled dynamically, we won't specialize it by default;
+instead, whether or not it gets specialized depends on its usage.
+
 Duck shaping
 ^^^^^^^^^^^^
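As a concrete illustration of the documented policy (a minimal sketch, not part of this commit): a dimension still specializes when it observes size 0 or 1 even under `dynamic=True`, which is where guards like `2 <= L['a'].size()[0]` come from, while a plain Python int is now left alone until its usage decides.

import torch

@torch.compile(dynamic=True)
def f(a):
    return a * 2

f(torch.randn(4))  # compiled with a symbolic size s0 and a guard 2 <= s0
f(torch.randn(8))  # reuses the dynamic graph
f(torch.randn(1))  # size 1 still hits the 0/1 rule: a second, specialized graph
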
test/dynamo/test_backward_higher_order_ops.py

Lines changed: 5 additions & 2 deletions
@@ -192,17 +192,20 @@ def fn(x, y):
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_inputs_ : list):
+    def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
         l_inputs_ = L_inputs_
+        l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
 
         getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
 
         new_grad: "f32[s0]" = torch.clone(getitem)
 
+        add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None
+
         result: "f32[s0]" = getitem * getitem; getitem = None
 
         new_grad_1: "f32[s0]" = torch.clone(result); result = None
-        return (new_grad, new_grad_1)
+        return (new_grad, new_grad_1, add)
 """,
         )
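
The expected graph gains a `Sym(s1)` input and an extra output because the hook's captured counter is no longer burned in as a constant. A hypothetical setup in the spirit of this test (names are illustrative, not the test's actual body): a backward hook closes over an object whose int field it increments, and that int now flows through the graph symbolically.

import torch

class Obj:
    def __init__(self):
        self.counter = 0  # previously specialized at 0; now traced as Sym(s1)

obj = Obj()

def hook(grad, *, obj=obj):
    obj.counter += 1  # appears in the graph as s1 + 1 and is returned as an output
    return torch.clone(grad)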

test/dynamo/test_misc.py

Lines changed: 6 additions & 3 deletions
@@ -2651,9 +2651,12 @@ def fn(x: int, y: torch.Tensor):
         ref = fn(x, y)
         res = opt_fn(x, y)
         self.assertEqual(ref, res)
-        # It's all traced once with x = 1, x = 2 and then x = ks0
-        # For dynamic it's x=1 and x=ks0
-        self.assertEqual(cnts.frame_count, ifdynstaticdefault(3, 2))
+        # It's all traced once with x = 1 and then x = ks0
+        # For dynamic it's x=ks0
+        if torch._dynamo.config.assume_static_by_default:
+            self.assertExpectedInline(str(cnts.frame_count), """2""")
+        else:
+            self.assertExpectedInline(str(cnts.frame_count), """2""")
 
     def test_numpy_with_builtin_type(self):
         x = np.random.rand(5)
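
For context, a standalone sketch of how these counts are measured (illustrative, not the test's actual body): `CompileCounter` bumps `frame_count` once per compiled graph, so under `assume_static_by_default` the first int value is specialized and the second distinct value triggers exactly one recompile with a symbolic int.

import torch
from torch._dynamo.testing import CompileCounter

cnts = CompileCounter()

@torch._dynamo.optimize(cnts)
def fn(x: int, y: torch.Tensor):
    return y * x

fn(1, torch.randn(3))  # first compile: x specialized to 1
fn(2, torch.randn(3))  # new value: one recompile with x as a symbolic int
assert cnts.frame_count == 2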

test/dynamo/test_modules.py

Lines changed: 4 additions & 1 deletion
@@ -1238,7 +1238,10 @@ def test_self_mutating1(self):
         out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
         self.assertTrue(torch._dynamo.testing.same(out2, out3))
         self.assertTrue(torch._dynamo.testing.same(out2, out4))
-        self.assertEqual(cnt.frame_count, 3)
+        if torch._dynamo.config.assume_static_by_default:
+            self.assertExpectedInline(cnt.frame_count, """2""")
+        else:
+            self.assertExpectedInline(cnt.frame_count, """1""")
 
     @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
     def test_generation_tag(self):
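
The lower counts follow directly from the 0/1 change. A hypothetical module in the spirit of test_self_mutating1 (not its actual body): an int attribute mutated on every call walks through 0, 1, 2, ..., so previously 0 and 1 each forced their own specialized frame, three frames by the third call, whereas now the second distinct value already flips the int to dynamic.

import torch

class SelfMutating(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def forward(self, x):
        self.counter += 1  # takes a new value on every call
        return x * self.counter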

test/dynamo/test_repros.py

Lines changed: 6 additions & 3 deletions
@@ -4350,7 +4350,7 @@ def fn(x, y):
         opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
         x = torch.rand([2, 2])
         opt_fn(x, x)
-        self.assertEqual(cnt.frame_count, 1)
+        self.assertExpectedInline(cnt.frame_count, """1""")
 
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_unbacked_arange_in_bounds(self):
@@ -4419,7 +4419,7 @@ def fn(x, y):
         opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
         x = torch.rand([2, 2])
         self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
-        self.assertEqual(cnt.frame_count, 1)
+        self.assertExpectedInline(cnt.frame_count, """1""")
 
     def test_user_ctor_ctx_manager_custom_init_graph_break(self):
         counter = [0]
@@ -4447,7 +4447,10 @@ def fn(x, counter):
         for i in range(0, 10):
             opt_fn(x, counter)
         self.assertEqual(counter[0], 12)
-        self.assertEqual(cnt.frame_count, torch._dynamo.utils.ifdynstaticdefault(3, 2))
+        if torch._dynamo.config.assume_static_by_default:
+            self.assertExpectedInline(cnt.frame_count, """2""")
+        else:
+            self.assertExpectedInline(cnt.frame_count, """1""")
 
     @unittest.expectedFailure
     def test_many_overlapping_inputs_does_not_explode_guards(self):
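
These hunks also swap `assertEqual` for `assertExpectedInline`, which keeps the expected count inline in the source so it can be re-recorded mechanically the next time behavior shifts. A minimal, self-contained illustration of that expecttest workflow (not code from this commit):

# Rerunning a failing test with EXPECTTEST_ACCEPT=1 rewrites the inline
# triple-quoted literal in place, so the new count shows up as a source diff.
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleTest(TestCase):
    def test_count(self):
        self.assertExpectedInline(str(1 + 1), """2""")

if __name__ == "__main__":
    run_tests()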

test/test_linalg.py

Lines changed: 2 additions & 0 deletions
@@ -2832,6 +2832,7 @@ def test_inverse_many_batches_helper(torch_inverse, b, n):
     @skipCPUIfNoLapack
     @onlyNativeDeviceTypes  # TODO: XLA doesn't raise exception
     @dtypes(*floating_and_complex_types())
+    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
     def test_inverse_errors(self, device, dtype):
         # inverse expects batches of square matrices as input
         with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
@@ -2976,6 +2977,7 @@ def test_pinv_errors_and_warnings(self, device, dtype):
     @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
+    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
     def test_inv_errors_and_warnings(self, device, dtype):
         # inv expects batches of square matrices as input
         a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)

torch/_dynamo/variables/builder.py

Lines changed: 11 additions & 2 deletions
@@ -17,6 +17,8 @@
 import weakref
 from typing import Any, List, NamedTuple, Optional, Union
 
+from torch._utils_internal import justknobs_check
+
 from torch.utils._sympy.value_ranges import ValueRanges
 
 try:
@@ -1248,15 +1250,22 @@ def wrap_literal(self, value):
             # unspecializing int by default, but still
             # specialize for the following conditions
             if not TracingContext.get().force_unspec_int_unbacked_size_like and (
-                value in self._common_constants()
                 # Assume integers from global variables want to be specialized
-                or not self.source.guard_source().is_local()
+                not self.source.guard_source().is_local()
                 # Assume that integers that came from NN modules want to be
                 # specialized (as we don't expect users to be changing the
                 # NN modules on the fly)
                 or self.source.guard_source().is_nn_module()
                 or is_from_defaults(self.source)
                 or is_cell_contents(self.source)
+                # TODO: Delete this condition when rollout is done. NB: this
+                # condition never evaluates True in open source
+                or (
+                    not justknobs_check(
+                        "pytorch/dynamo:enable_unspecialize_zero_one_plain_int"
+                    )
+                    and value in self._common_constants()
+                )
             ):
                 self.install_guards(GuardBuilder.CONSTANT_MATCH)
                 return ConstantVariable.create(value=value, source=self.source)
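
Restated outside Dynamo's plumbing, the new decision reads roughly like the sketch below. This is a simplified paraphrase, not the real API: the predicate arguments stand in for the `self.source.guard_source()` queries, and `rollout_enabled` stands in for the `justknobs_check` call, which always returns True in open source, making the legacy 0/1 branch dead code there.

def should_specialize_plain_int(
    value: int,
    is_local: bool,            # did the int come from a local variable?
    is_nn_module: bool,        # attribute of an nn.Module
    from_defaults: bool,       # function default argument
    from_cell_contents: bool,  # captured closure cell
    rollout_enabled: bool = True,  # justknobs_check(...); always True in OSS
) -> bool:
    return (
        not is_local           # globals are assumed to want specialization
        or is_nn_module        # nn.Module ints are assumed static
        or from_defaults
        or from_cell_contents
        # legacy behavior, kept only while the internal rollout knob is off
        or (not rollout_enabled and value in (0, 1))
    )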
