17 changes: 11 additions & 6 deletions test/distributed/test_dynamo_distributed.py
@@ -84,19 +84,24 @@ def __init__(self):
             super(MyModule, self).__init__()
             mods = [
                 (MyLinear(), torch.nn.ReLU()),
-                # sandwitch the custom in the middle so it comes before and after
+                # sandwich the custom in the middle so it comes before and after
                 (MyCustomLinear(), torch.nn.ReLU()),
                 (MyLinear(), torch.nn.ReLU()),
             ]
             self.seq = torch.nn.Sequential(*[x for items in mods for x in items])
 
-        def forward(self, x):
-            return self.seq(x)
+        def forward(self, x, y):
+            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
+            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
+            # optimize this case by fusing that 'empty bucket' back together with the previous full one
+            return self.seq(x + y)
 
     m = MyModule().to(device)
     m.apply(init_weights)
     inputs = torch.rand((512, 512)).to(device)
-    correct_outputs = m(inputs)
+    # test duplicated inputs
+    inputs = (inputs, inputs)
+    correct_outputs = m(*inputs)
     return m, inputs, correct_outputs
 
 def get_hf_bert(rank):
@@ -520,7 +525,7 @@ def test_custom_layer(self):
 
         @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
-            return ddp_m(inputs)
+            return ddp_m(*inputs)
 
         opt_outputs = opt_fn(inputs)
         self.assertTrue(same(correct_outputs, opt_outputs))
@@ -563,7 +568,7 @@ def test_ignored_parameters(self):
 
         @torch._dynamo.optimize(ddp_optimizer.compile_fn)
         def opt_fn(inputs):
-            return ddp_m(inputs)
+            return ddp_m(*inputs)
 
         opt_outputs = opt_fn(inputs)
         self.assertTrue(same(correct_outputs, opt_outputs))
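
Before moving on to the optimizer change, the test fixture above boils down to the pattern below. This is a hedged, single-process sketch only: plain torch.nn.Linear layers stand in for MyLinear/MyCustomLinear, and dynamo's "eager" backend stands in for check_splits_compiler.compile_fn; the real test wraps the module in DistributedDataParallel and compares against the eager reference outputs as shown in the hunks above.

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain linears stand in for the MyLinear/MyCustomLinear layers in the test file.
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(512, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, 512), torch.nn.ReLU(),
        )

    def forward(self, x, y):
        # x + y is a parameter-free op sitting right at the graph input,
        # the case the new bucket-fusion logic in DDPOptimizer is written for.
        return self.seq(x + y)

m = MyModule()
inputs = torch.rand((512, 512))
inputs = (inputs, inputs)            # duplicated inputs, as in the updated fixture
correct_outputs = m(*inputs)

@torch._dynamo.optimize("eager")     # stand-in for check_splits_compiler.compile_fn
def opt_fn(inputs):
    return m(*inputs)

opt_outputs = opt_fn(inputs)
assert torch.allclose(correct_outputs, opt_outputs)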
6 changes: 6 additions & 0 deletions torch/_dynamo/optimizations/distributed.py
@@ -183,6 +183,12 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
             # Ignored params still end up in buckets, we just don't count them towards the capacity
             buckets[0].nodes.append(node)
 
+        if len(buckets) > 1 and buckets[0].size == 0:
+            # we collected a small preamble graph with ops that don't include parameters, fuse it back
+            buckets[1].nodes.extend(buckets[0].nodes)
+            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
+            del buckets[0]
+
         # stash buckets for testing/debugging purposes
         self.buckets = buckets
         log.info(
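
The distributed.py change itself is small: after all buckets have been assigned, a leading bucket that collected only parameter-free ops is merged back into its neighbour. Below is a minimal, self-contained sketch of that step; Bucket and fuse_leading_empty_bucket are simplified stand-ins for DDPOptimizer's internal bucket records, not the real classes in torch/_dynamo/optimizations/distributed.py.

from dataclasses import dataclass, field
from typing import List

@dataclass
class Bucket:
    size: int = 0                                    # bytes of parameters collected so far
    params: List[str] = field(default_factory=list)  # parameter names
    nodes: List[str] = field(default_factory=list)   # fx nodes; plain names here for brevity

def fuse_leading_empty_bucket(buckets: List[Bucket]) -> List[Bucket]:
    # Buckets are filled while walking the graph from output towards input, so
    # buckets[0] ends up holding the ops closest to the graph input. If the
    # previous bucket filled up exactly and only parameter-free ops (e.g. the
    # x + y preamble in the test) remain, buckets[0] has size 0: fold it into
    # buckets[1] instead of emitting an extra graph break for it.
    if len(buckets) > 1 and buckets[0].size == 0:
        assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
        buckets[1].nodes.extend(buckets[0].nodes)
        del buckets[0]
    return buckets

# A parameter-free 'add' preamble is folded into the already-full bucket behind it.
buckets = [
    Bucket(size=0, params=[], nodes=["add"]),
    Bucket(size=25_000_000, params=["lin.weight", "lin.bias"], nodes=["lin", "relu"]),
]
assert len(fuse_leading_empty_bucket(buckets)) == 1

The net effect, which the new test exercises, is that a parameter-free preamble such as x + y no longer produces an extra subgraph of its own.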