Skip to content

Commit d7680a5

Browse files
ezyangpytorchmergebot
authored andcommitted
Bug fixes for disabling 0/1 specialization on plain int (#129961)
These bug fixes will be exercised in #128327 but I separate them from the actual policy change (which is more risky) Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #129961 Approved by: https://github.com/lezcano
1 parent 29ffa20 commit d7680a5

File tree

5 files changed

+42
-21
lines changed

5 files changed

+42
-21
lines changed

test/test_dynamic_shapes.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,18 +1256,12 @@ def get_constant_bool(self, val):
12561256
def test_symnode_hashing(self):
12571257
shape_env = ShapeEnv()
12581258

1259-
# SymInt, SymBool, SymFloat are unhashable
1260-
unhashable = (
1261-
create_symint(shape_env, 3),
1262-
create_symbool(shape_env, True),
1263-
# We should be passing in float here, but create_symbol currently
1264-
# only supports int
1265-
create_symfloat(shape_env, 3.0),
1266-
)
1267-
1268-
for x in unhashable:
1269-
with self.assertRaisesRegex(TypeError, "unhashable"):
1270-
hash(x)
1259+
# These all trigger specialization when hashed
1260+
hash(create_symint(shape_env, 3))
1261+
hash(create_symbool(shape_env, True))
1262+
# We should be passing in float here, but create_symbol currently
1263+
# only supports int
1264+
hash(create_symfloat(shape_env, 3.0))
12711265

12721266
# NestedInt (SymInt), constant SymBool, SymNode are hashable
12731267
j1 = torch._C._get_nested_int(1, 1)

torch/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,8 @@ def __hash__(self) -> builtins.int:
511511
if self.node.is_nested_int():
512512
return hash(self.node.nested_int())
513513
else:
514-
# We could support constant SymInts as well, but not doing it for now
515-
raise TypeError("unhashable type: non-nested SymInt")
514+
# Force specialization
515+
return hash(builtins.int(self))
516516

517517

518518
class SymFloat:
@@ -550,6 +550,9 @@ def __rfloordiv__(self, other):
550550
def __bool__(self):
551551
return self.node.bool_()
552552

553+
def __float__(self):
554+
return self.node.guard_float("", 0)
555+
553556
# Symbolic power does NOT work with negative base, this is to avoid
554557
# potential complex outputs
555558
def __pow__(self, other):
@@ -612,6 +615,13 @@ def is_integer(self):
612615
def __repr__(self):
613616
return self.node.str()
614617

618+
def __hash__(self):
619+
if self.node.is_constant():
620+
return hash(self.node.float_())
621+
else:
622+
# Force specialization
623+
return hash(builtins.float(self))
624+
615625

616626
class SymBool:
617627
"""
@@ -674,7 +684,8 @@ def __hash__(self):
674684
if self.node.is_constant():
675685
return hash(self.node.bool_())
676686
else:
677-
raise TypeError("unhashable type: SymBool")
687+
# Force specialization
688+
return hash(builtins.bool(self))
678689

679690

680691
def sym_not(a):

torch/_dynamo/variables/lists.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,8 @@ def call_method(
324324
args: List["VariableTracker"],
325325
kwargs: Dict[str, "VariableTracker"],
326326
) -> "VariableTracker":
327+
from .tensor import SymNodeVariable
328+
327329
if name == "append" and self.mutable_local:
328330
assert not kwargs
329331
(arg,) = args
@@ -345,7 +347,10 @@ def call_method(
345347
elif name == "insert" and self.mutable_local:
346348
assert not kwargs
347349
idx, value = args
348-
const_idx = idx.as_python_constant()
350+
if isinstance(idx, SymNodeVariable):
351+
const_idx = idx.evaluate_expr()
352+
else:
353+
const_idx = idx.as_python_constant()
349354
tx.output.side_effects.mutation(self)
350355
self.items.insert(const_idx, value)
351356
return ConstantVariable.create(None)

torch/_inductor/lowering.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,12 @@ def squeeze(x, dim=None):
764764
if dim is None:
765765
return TensorBox(SqueezeView.create(x.data))
766766

767-
dim = canonicalize_dims(len(x.get_size()), dim)
767+
dim = (
768+
V.graph.sizevars.evaluate_static_shape(dim)
769+
if isinstance(dim, (int, sympy.Expr))
770+
else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim)
771+
)
772+
dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload]
768773
dims = set((dim,) if not isinstance(dim, tuple) else dim)
769774

770775
new_shape = []
@@ -1589,7 +1594,7 @@ def unsqueeze_(x, dim):
15891594

15901595

15911596
def _validate_dim(x, dim, offset=0):
1592-
assert isinstance(dim, int)
1597+
dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
15931598
ndim = len(x.get_size())
15941599
if dim < 0:
15951600
dim += ndim + offset

torch/_inductor/sizevars.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,9 @@ def evaluate_max(self, left: Expr, right: Expr) -> Expr:
448448
min_val = self.evaluate_min(left, right)
449449
return right if min_val is left else left
450450

451-
def evaluate_static_shape(self, left: Expr) -> int:
451+
def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
452+
if isinstance(left, int):
453+
return left
452454
right = self.size_hint(left)
453455
self.guard_equals(left, sympy.Integer(right))
454456
return int(right)
@@ -461,7 +463,9 @@ def remove_precomputed_replacements(self, expr: Expr) -> Expr:
461463
return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
462464
return expr
463465

464-
def symbolic_hint(self, expr: Expr) -> Union[Expr, int]:
466+
def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
467+
if isinstance(expr, int):
468+
return expr
465469
# Substitute all hints into expr, but leave unbacked symints alone
466470
expr = self.simplify(expr)
467471
if not isinstance(expr, Expr):
@@ -476,7 +480,9 @@ def symbolic_hint(self, expr: Expr) -> Union[Expr, int]:
476480
expr = self.remove_precomputed_replacements(expr)
477481
return sympy_subs(expr, self.var_to_val)
478482

479-
def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
483+
def size_hint(
484+
self, expr: Union[Expr, int], *, fallback: Optional[int] = None
485+
) -> int:
480486
out = self.symbolic_hint(expr)
481487
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
482488
# Use the provided heuristic fallback hint

0 commit comments

Comments
 (0)