Commit cac8faa

Fix DDPOptimizer fake_mode execution
When running compiled submods to produce outputs to pass to the compilation step for the next submod, we use fake parameters and assume fake inputs, but we forgot to activate our fake_mode during execution. This broke certain edge cases where tensors other than activations or parameters get created during execution, such as the scalar->tensor expansion when executing torch.where(tensor, scalar, scalar). Also add a test and clarify the behavior of DDPOptimizer via comments.

[ghstack-poisoned]
1 parent b399007 commit cac8faa
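
To make the failure concrete, here is a rough sketch (not part of the patch) of the mismatch described above. It assumes FakeTensorMode from torch._subclasses.fake_tensor behaves as in recent builds; the exact error text, and whether the first call actually raises, vary by version.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    fake_mode = FakeTensorMode()
    # A fake activation, standing in for the output of a previously compiled submod.
    fake_x = fake_mode.from_tensor(torch.randn(4, 4))

    # Without the mode active, the scalars 0.3/0.6 can be expanded into *real*
    # tensors, mixing real and fake inputs to aten.where.self.
    try:
        torch.where(fake_x < 0.5, 0.3, 0.6)
    except Exception as e:
        print("outside fake_mode:", e)

    # With the mode active (what the fix does around submod execution), every
    # intermediate tensor is created as a FakeTensor and the call goes through.
    with fake_mode:
        out = torch.where(fake_x < 0.5, 0.3, 0.6)
    print(type(out))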

File tree: 2 files changed (+20 / -6 lines)

  test/distributed/test_dynamo_distributed.py
  torch/_dynamo/optimizations/distributed.py


test/distributed/test_dynamo_distributed.py

Lines changed: 4 additions & 1 deletion

@@ -66,7 +66,10 @@ def __init__(self):
         self.weight = nn.Parameter(torch.randn(512, 512))
 
     def forward(self, x):
-        return torch.mm(x, self.weight.t())
+        tmp = torch.mm(x, self.weight.t())
+        # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
+        # and the tensors T(0.4) and T(0.5) were not wrapped in FakeTensors during DDPOptimizer compilation
+        return tmp + torch.where(tmp < 0.5, 0.3, 0.6)
 
 class MyLinear(torch.nn.Module):
     def __init__(self):
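
For context on how a module like this exercises DDPOptimizer: the optimizer only engages when a DDP-wrapped module is compiled through torch._dynamo. The sketch below is illustrative rather than the actual test harness; the single-rank gloo process group, the "aot_eager" backend, the small bucket_cap_mb (to encourage the graph to split at bucket boundaries), and the ToyModule name are all assumptions made for the demo.

    import os
    import torch
    import torch._dynamo
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    class ToyModule(nn.Module):  # illustrative stand-in for the module in the diff above
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)


    if __name__ == "__main__":
        # Single-rank process group, just so DDP can be constructed for illustration.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        # Small buckets so DDPOptimizer is more likely to split the graph into several submods.
        model = DDP(ToyModule(), bucket_cap_mb=1)
        compiled = torch._dynamo.optimize("aot_eager")(model)
        out = compiled(torch.randn(8, 512))
        print(out.shape)

        dist.destroy_process_group()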

torch/_dynamo/optimizations/distributed.py

Lines changed: 16 additions & 5 deletions

@@ -296,8 +296,6 @@ def run_node(self, n: Node) -> Any:
                 assert isinstance(args, tuple)
                 assert isinstance(kwargs, dict)
 
-                # modify the currently running FX graph
-                # maybe this isn't sound in general, but only changing the target of a node might be ok?
                 if n.op == "call_module":
                     real_mod = self.fetch_attr(n.target)
                     if fake_mode:
@@ -308,15 +306,28 @@ def run_node(self, n: Node) -> Any:
                     log.debug(
                         f"\n---{n.target} graph---\n" + str(curr_submod.graph)
                     )
+
+                    # When calling the compiler on the submod, inputs (new_args) are expected to
+                    # be FakeTensors already since Dynamo would have made them FakeTensors in the
+                    # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
+                    # since this wrapping happens during compilation
                     compiled_submod_real = self.compile_submod(
                         real_mod, new_args, kwargs
                     )
+
+                    # We update the original (outer) graph with a call into the compiled module
+                    # instead of the uncompiled one.
                     self.module.delete_submodule(n.target)
                     n.target = "compiled_" + n.target
                     self.module.add_submodule(n.target, compiled_submod_real)
-                    return curr_submod(*new_args, **kwargs)
-                # then we execute the modified node using the usual logic
-                return getattr(self, n.op)(n.target, new_args, kwargs)
+
+                    # Finally, we have to produce inputs for use compiling the next submodule,
+                    # and these need to be FakeTensors, so we execute the module under fake_mode
+                    with fake_mode:
+                        return curr_submod(*new_args, **kwargs)
+                else:
+                    # placeholder or output nodes don't need to get compiled, just executed
+                    return getattr(self, n.op)(n.target, new_args, kwargs)
 
         submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
         submod_compiler.run(*example_inputs)
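
Pulling the pieces together, the ordering run_node enforces can be summarized with a minimal sketch of the same pattern: compile each split submodule against fake inputs, then execute it under fake_mode so its outputs are FakeTensors for the next compilation step. The names here (compile_chain, the stand-in compiler) are illustrative rather than the real SubmodCompiler internals, and the submodules are parameter-free so the sketch stays runnable without the fake-parameter copy that DDPOptimizer also performs.

    import torch
    import torch.nn as nn
    from torch._subclasses.fake_tensor import FakeTensorMode


    def compile_chain(submods, real_example_input, compiler):
        fake_mode = FakeTensorMode()
        fake_x = fake_mode.from_tensor(real_example_input)
        compiled = []
        for submod in submods:
            # Compile against fake inputs, as Dynamo would have provided them.
            compiled.append(compiler(submod, [fake_x]))
            # Execute under fake_mode so any tensor created along the way (e.g. the
            # scalar->tensor expansion inside torch.where) is also a FakeTensor.
            with fake_mode:
                fake_x = submod(fake_x)
        return compiled


    if __name__ == "__main__":
        class WhereScalars(nn.Module):
            def forward(self, x):
                return torch.where(x < 0.5, 0.3, 0.6)

        # A stand-in "compiler" that just returns the module unchanged.
        compiled = compile_chain([nn.ReLU(), WhereScalars()], torch.randn(4, 4), lambda m, inputs: m)
        print(len(compiled))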
