Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion torch/_guards.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import dataclasses
import enum
import logging
import weakref
from typing import Callable, List, Optional
from typing import Callable, List, NamedTuple, Optional

# TODO(voz): Stolen pattern, not sure why this is the case,
# but mypy complains.
try:
import sympy # type: ignore[import]
except ImportError:
logging.warning("No sympy found")

"""
torch._guards is the definitional source of truth for general purpose guard structures.
Expand Down Expand Up @@ -52,6 +60,11 @@ class GuardBuilderBase:
pass


class ShapeGuard(NamedTuple):
expr: sympy.Expr
stack: str


@dataclasses.dataclass
class Guard:
# The name of a Guard specifies what exactly it is the guard is guarding
Expand Down
9 changes: 5 additions & 4 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import textwrap
import logging
from torch import SymInt, SymFloat
from torch._guards import ShapeGuard

try:
import sympy # type: ignore[import]
Expand Down Expand Up @@ -467,7 +468,7 @@ def _print_Symbol(self, expr) -> str:

class ShapeEnv(object):
def __init__(self):
self.guards = []
self.guards: List[ShapeGuard] = []
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {}
Expand Down Expand Up @@ -896,9 +897,9 @@ def evaluate_expr(self, expr: "sympy.Expr"):
if not self._suppress_guards_tls():
stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2]))
if concrete_val is sympy.true:
self.guards.append((expr, stack))
self.guards.append(ShapeGuard(expr, stack))
elif concrete_val is sympy.false:
self.guards.append((sympy.Not(expr), stack))
self.guards.append(ShapeGuard(sympy.Not(expr), stack))
else:
self.guards.append((sympy.Eq(expr, concrete_val), stack))
self.guards.append(ShapeGuard(sympy.Eq(expr, concrete_val), stack))
return concrete_val