Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/distributed/test_dynamo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,13 @@ def opt_fn(inputs):
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)

# ensure compatibility with dynamo explain

explain_out = torch._dynamo.explain(ddp_m, inputs)
break_reasons = explain_out[4]
self.assertEqual(len(break_reasons), 3)
self.assertTrue(all(["DDPOptimizer" in r.reason for r in break_reasons]))

@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor(self):
Expand Down
12 changes: 11 additions & 1 deletion torch/_dynamo/backends/distributed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, fake_mode_from_tensors
from torch.fx.node import Node

Expand Down Expand Up @@ -54,7 +56,7 @@ def pretty_print_buckets(buckets: List[Bucket]):


class DDPOptimizer:
"""
"""Note [DDPOptimizer]
DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
the boundaries of gradient-allreduce buckets chosen by DDP.
Expand Down Expand Up @@ -259,6 +261,14 @@ def forward(self, *args):
sn.args = (sn.args,)

input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
[
# it's close to useless to get a real stacktrace here, and quite verbose.
traceback.FrameSummary(__file__, 0, DDPOptimizer),
],
)
wrapper = WrapperModule(
self.compiler(input_mod, args),
unwrap_singleton_tuple,
Expand Down