Skip to content

Commit 63fbd7c

Browse files
williamwen42BoyuanFeng
authored andcommitted
[dynamo] clarify graph break handling/logging in symbolic_convert (#166587)
Pull Request resolved: #166587 Approved by: https://github.com/Lucaskabela ghstack dependencies: #166476, #166477, #166586
1 parent 1b36201 commit 63fbd7c

File tree

4 files changed

+291
-99
lines changed

4 files changed

+291
-99
lines changed

test/dynamo/test_error_messages.py

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,7 @@ def fn(x):
11591159
torch._dynamo.graph_break()
11601160
11611161
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
1162+
11621163
Most recent bytecode instructions traced (max 20):
11631164
TRACE RESUME 0 []
11641165
TRACE LOAD_FAST 'x' []
@@ -1172,7 +1173,8 @@ def fn(x):
11721173
TRACE LOAD_GLOBAL 'torch' []
11731174
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker(unrealized: <class 'module'>)]
11741175
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker(unrealized: <class 'module'>)]
1175-
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]""",
1176+
TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: <class 'function'>)]
1177+
""",
11761178
)
11771179

11781180
@torch._dynamo.config.patch(verbose=True)
@@ -1234,17 +1236,28 @@ def f3(x):
12341236
self.assertIn("Foo().attr = x # 1", records[-1].getMessage())
12351237

12361238
def post_munge(s):
1237-
return re.sub(
1239+
s = re.sub(
12381240
r"torch_dynamo_resume_in_f(\d)_at_(\d+)",
12391241
r"torch_dynamo_resume_in_f\1_at_N",
12401242
s,
12411243
)
1244+
# remove most recent bytecode instructions
1245+
# DOTALL is needed to entirely remove TRACE ... lines (including the newline)
1246+
return re.sub(r"TRACE.*$", "", s, flags=re.DOTALL)
12421247

12431248
self.assertExpectedInline(
12441249
post_munge(munge_exc(records[-1].getMessage(), skip=0)),
12451250
"""\
12461251
Graph break in user code at test_error_messages.py:N
1247-
Graph Break Reason: STORE_ATTR-caused graph break
1252+
Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR):
1253+
1254+
Call to `torch._dynamo.graph_break()`
1255+
Explanation: User-inserted graph break. Message: None
1256+
Hint: Remove the `torch._dynamo.graph_break()` call.
1257+
1258+
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
1259+
1260+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
12481261
User code traceback:
12491262
File "test_error_messages.py", line N, in test_graph_break_traceback_above_dynamo_shows_user_code
12501263
f3(torch.randn(3))
@@ -1257,8 +1270,12 @@ def post_munge(s):
12571270
12581271
File "test_error_messages.py", line N, in torch_dynamo_resume_in_f3_at_N
12591272
Foo().attr = x
1273+
File "test_error_messages.py", line N, in __setattr__
1274+
torch._dynamo.graph_break()
12601275
12611276
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
1277+
1278+
Most recent bytecode instructions traced (max 20):
12621279
""",
12631280
)
12641281

@@ -1483,6 +1500,110 @@ def bad_clean_and_assemble_instructions(instructions, *args):
14831500
):
14841501
fn(torch.randn(3))
14851502

1503+
@make_logging_test(graph_breaks=True)
1504+
def test_step_graph_break(self, records):
1505+
@torch.compile(backend="eager")
1506+
def fn(x):
1507+
x = x + 1
1508+
x = x + 2
1509+
torch._dynamo.step_unsupported()
1510+
return x + 4
1511+
1512+
fn(torch.ones(3))
1513+
1514+
self.assertExpectedInline(
1515+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1516+
"""\
1517+
Graph break in user code at test_error_messages.py:N
1518+
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
1519+
1520+
User code traceback:
1521+
File "test_error_messages.py", line N, in test_step_graph_break
1522+
fn(torch.ones(3))
1523+
File "test_error_messages.py", line N, in fn
1524+
torch._dynamo.step_unsupported()
1525+
""",
1526+
)
1527+
1528+
torch._dynamo.reset()
1529+
1530+
with torch._dynamo.error_on_graph_break(True):
1531+
self.assertExpectedInlineMunged(
1532+
Unsupported,
1533+
lambda: fn(torch.ones(3)),
1534+
"""\
1535+
cannot resume from torch._dynamo.step_unsupported()
1536+
Explanation: traced torch._dynamo.step_unsupported(), but Dynamo is instructed to error on graph break. This graph break is used for debugging only.
1537+
Hint: Remove the torch._dynamo.step_unsupported() call.
1538+
Hint: Make sure fullgraph=False and error_on_graph_break=False.
1539+
Hint: This is likely to be a Dynamo bug. Please report an issue to PyTorch.
1540+
1541+
Developer debug context:
1542+
1543+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0283.html
1544+
1545+
from user code:
1546+
File "test_error_messages.py", line N, in fn
1547+
torch._dynamo.step_unsupported()""",
1548+
)
1549+
1550+
@make_logging_test(graph_breaks=True)
1551+
def test_store_attr_graph_break(self, records):
1552+
class Foo:
1553+
def __setattr__(self, name, value):
1554+
torch._dynamo.graph_break()
1555+
1556+
@torch.compile(backend="eager")
1557+
def fn(x):
1558+
Foo().attr = x
1559+
1560+
fn(torch.ones(3))
1561+
1562+
self.assertExpectedInline(
1563+
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
1564+
"""\
1565+
Graph break in user code at test_error_messages.py:N
1566+
Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR):
1567+
1568+
Call to `torch._dynamo.graph_break()`
1569+
Explanation: User-inserted graph break. Message: None
1570+
Hint: Remove the `torch._dynamo.graph_break()` call.
1571+
1572+
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
1573+
1574+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
1575+
User code traceback:
1576+
File "test_error_messages.py", line N, in test_store_attr_graph_break
1577+
fn(torch.ones(3))
1578+
File "test_error_messages.py", line N, in fn
1579+
Foo().attr = x
1580+
File "test_error_messages.py", line N, in __setattr__
1581+
torch._dynamo.graph_break()
1582+
""",
1583+
)
1584+
1585+
torch._dynamo.reset()
1586+
1587+
with torch._dynamo.error_on_graph_break(True):
1588+
self.assertExpectedInlineMunged(
1589+
Unsupported,
1590+
lambda: fn(torch.ones(3)),
1591+
"""\
1592+
Call to `torch._dynamo.graph_break()`
1593+
Explanation: User-inserted graph break. Message: None
1594+
Hint: Remove the `torch._dynamo.graph_break()` call.
1595+
1596+
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
1597+
1598+
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
1599+
1600+
from user code:
1601+
File "test_error_messages.py", line N, in fn
1602+
Foo().attr = x
1603+
File "test_error_messages.py", line N, in __setattr__
1604+
torch._dynamo.graph_break()""",
1605+
)
1606+
14861607

14871608
if __name__ == "__main__":
14881609
from torch._dynamo.test_case import run_tests

torch/_dynamo/exc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ class RecompileLimitExceeded(Unsupported):
273273

274274
# debug exception thrown when tracing torch._dynamo.step_unsupported()
275275
class StepUnsupported(TorchDynamoException):
276-
pass
276+
def __init__(self) -> None:
277+
self.real_stack = torch._guards.TracingContext.extract_stack()
277278

278279

279280
class UnsafeScriptObjectError(TorchDynamoException):

torch/_dynamo/graph_break_registry.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,5 +2851,17 @@
28512851
"Move the Placement usage outside the compiled region"
28522852
]
28532853
}
2854+
],
2855+
"GB0283": [
2856+
{
2857+
"Gb_type": "cannot resume from torch._dynamo.step_unsupported()",
2858+
"Context": "",
2859+
"Explanation": "traced torch._dynamo.step_unsupported(), but Dynamo is instructed to error on graph break. This graph break is used for debugging only.",
2860+
"Hints": [
2861+
"Remove the torch._dynamo.step_unsupported() call.",
2862+
"Make sure fullgraph=False and error_on_graph_break=False.",
2863+
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
2864+
]
2865+
}
28542866
]
28552867
}

0 commit comments

Comments
 (0)