Skip to content

Commit 53e71fa

Browse files
voznesenskympytorchmergebot
authored andcommitted
Add shape_env guards to tracing context (#90876)
Pull Request resolved: #90876 Approved by: https://github.com/Chillee, https://github.com/ezyang
1 parent a01c1ee commit 53e71fa

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

torch/_dynamo/output_graph.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,24 @@ def __init__(
196196
super(OutputGraph, self).__init__()
197197
self.graph = torch.fx.Graph()
198198
self.graphargs: List[GraphArg] = []
199-
self.tracing_context: TracingContext = TracingContext()
199+
fake_mode = torch._subclasses.FakeTensorMode(
200+
throw_on_data_dependent_ops=True,
201+
shape_env=ShapeEnv() if config.dynamic_shapes else None,
202+
)
203+
self.tracing_context: TracingContext = TracingContext(fake_mode)
200204
# tracked_fakes says where any tensor that was wrapped to fake came
201205
# from. It is similar to GraphArg, in that all GraphArgs will get
202206
# will get added to TrackedFakes, but TrackedFakes also contains
203207
# GraphArgs that got pruned, and things like Tensor attributes which
204208
# aren't explicit graph inputs. Used by shape guard
205209
self.tracked_fakes: List[TrackedFake] = []
210+
# Although we prune unused graphargs before sending graphs to
211+
# compilers, we may have legitimately triggered shape guards
212+
# on "unused" inputs that we must keep track of. So after
213+
# remove_unused_graphargs is called, orig_graphargs and
214+
# graphargs no longer alias; orig_graphargs is the original
215+
# graphargs, and graphargs is the pruned list. Guard creation
216+
# should use original graphargs.
206217
self.orig_graphargs: List[GraphArg] = self.graphargs
207218
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
208219
self.side_effects = SideEffects()
@@ -228,7 +239,6 @@ def __init__(
228239
self.unspec_variable_map: Dict[
229240
str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
230241
] = {}
231-
self.shape_env = ShapeEnv() if config.dynamic_shapes else None
232242
self.intermediary_symbols: Dict[sympy.Expr, None] = {}
233243

234244
# Enables creating unique node names by tracking
@@ -245,6 +255,10 @@ def output(self):
245255
def fake_mode(self):
246256
return self.root_tx.fake_mode
247257

258+
@property
259+
def shape_env(self):
260+
return self.tracing_context.fake_mode.shape_env
261+
248262
@property
249263
def guards(self) -> Set[Guard]:
250264
return self.tracing_context.guards_context.dynamo_guards

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,10 +1577,7 @@ def __init__(
15771577
# Flag to indicate whether tracing is used for export.
15781578
self.export = export
15791579

1580-
self._fake_mode = torch._subclasses.FakeTensorMode(
1581-
throw_on_data_dependent_ops=True,
1582-
shape_env=output.shape_env,
1583-
)
1580+
self._fake_mode = output.tracing_context.fake_mode
15841581

15851582
self.checkpoint = None
15861583
self.random_calls = []

torch/_guards.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,9 @@ class TracingContext:
285285
def get() -> Optional["TracingContext"]:
286286
return _CURRENT_TRACING_CONTEXT
287287

288-
def __init__(self):
288+
def __init__(self, fake_mode):
289289
self.guards_context = GuardsContext()
290+
self.fake_mode = fake_mode
290291

291292

292293
"""

torch/fx/experimental/symbolic_shapes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def track_symint(source, val):
753753
try:
754754
exprs.append(ShapeGuardPrinter(symbol_to_source).doprint(g))
755755
except Exception:
756-
log.warning(f"Failing guard allocated at:\n{tb}")
756+
log.warning(f"Failing guard allocated at: \n{tb}")
757757
raise
758758

759759
# 3. Every symbol must not be equal to 0/1
@@ -821,15 +821,15 @@ def bind_symint(arg, val):
821821
return bindings
822822

823823
def get_nontrivial_guards(self):
824-
return [self.simplify(guard) for guard, _ in self.guards if self._maybe_evaluate_static(guard) is None]
824+
return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]
825825

826826
def format_guards(self, verbose=False):
827827
def format_tb(tb):
828828
if not verbose:
829829
return ""
830830
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
831831

832-
return '\n'.join(f" - {guard}{format_tb(tb)}" for guard, tb in self.guards)
832+
return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
833833

834834
def get_shape_groups(self):
835835
shape_groups = collections.defaultdict(list)

0 commit comments

Comments
 (0)