
Commit d6e39f5

Make DDPOptimizer work with torch._dynamo.explain()
GraphModules that were created during DDPOptimizer graph breaking lacked `compile_subgraph_reason`, which caused an exception when running .explain(). Now the reason is provided and users can use .explain() to find out that DDPOptimizer is causing graph breaks.

Fixes #94579

ghstack-source-id: fa51cd1
Pull Request resolved: #94749
1 parent 1f7448e commit d6e39f5
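
For context, a rough sketch of how a user could surface these break reasons after this change. The module, input shape, device, and the index of break reasons in the explain() output are assumptions drawn from the test added below, and a process group must already be initialized for DDP to construct:

    import torch
    import torch._dynamo
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Hypothetical model; any nn.Module wrapped in DDP will do.
    model = DDP(torch.nn.Linear(64, 64).cuda())
    inputs = torch.randn(8, 64, device="cuda")

    # explain() traces the model with dynamo and reports where and why graph breaks happened.
    explain_out = torch._dynamo.explain(model, inputs)
    break_reasons = explain_out[4]  # position taken from the test below (assumption)
    for r in break_reasons:
        print(r.reason)  # expect "DDPOptimizer intentional graph-break (See Note [DDPOptimizer]) ..."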

File tree: 2 files changed, +18 −1

  test/distributed/test_dynamo_distributed.py
  torch/_dynamo/backends/distributed.py

test/distributed/test_dynamo_distributed.py

Lines changed: 7 additions & 0 deletions
@@ -399,6 +399,13 @@ def opt_fn(inputs):
         self.assertTrue(same(correct_outputs, opt_outputs))
         self.assertEqual(check_splits_compiler.compiler_called, 3)
 
+        # ensure compatibility with dynamo explain
+
+        explain_out = torch._dynamo.explain(ddp_m, inputs)
+        break_reasons = explain_out[4]
+        self.assertEqual(len(break_reasons), 3)
+        self.assertTrue(all(["DDPOptimizer" in r.reason for r in break_reasons]))
+
     @patch.object(config, "optimize_ddp", True)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_graph_split_inductor(self):

torch/_dynamo/backends/distributed.py

Lines changed: 11 additions & 1 deletion
@@ -1,9 +1,11 @@
 import logging
+import traceback
 from dataclasses import dataclass, field
 from typing import Any, List, Optional
 
 import torch
 from torch import fx
+from torch._dynamo.output_graph import GraphCompileReason
 from torch._dynamo.utils import deepcopy_to_fake_tensor, fake_mode_from_tensors
 from torch.fx.node import Node
 
@@ -54,7 +56,7 @@ def pretty_print_buckets(buckets: List[Bucket]):
 
 
 class DDPOptimizer:
-    """
+    """Note [DDPOptimizer]
     DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
     breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
     the boundaries of gradient-allreduce buckets chosen by DDP.
@@ -259,6 +261,14 @@ def forward(self, *args):
             sn.args = (sn.args,)
 
         input_mod.recompile()
+        input_mod.compile_subgraph_reason = GraphCompileReason(
+            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
+            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
+            [
+                # it's close to useless to get a real stacktrace here, and quite verbose.
+                traceback.FrameSummary(__file__, 0, DDPOptimizer),
+            ],
+        )
         wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
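
As the new reason string points out, the splitting itself can be disabled if the extra graph breaks are unwanted. A minimal sketch (the config flag is the one referenced in the reason string above, not something added by this commit):

    import torch._dynamo

    # Turn off DDPOptimizer so dynamo no longer breaks the graph at DDP
    # gradient-allreduce bucket boundaries.
    torch._dynamo.config.optimize_ddp = False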
