5 changes: 4 additions & 1 deletion test/distributed/test_dynamo_distributed.py
@@ -66,7 +66,10 @@ def __init__(self):
         self.weight = nn.Parameter(torch.randn(512, 512))

     def forward(self, x):
-        return torch.mm(x, self.weight.t())
+        tmp = torch.mm(x, self.weight.t())
+        # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
+        # and the tensors T(0.3) and T(0.6) were not wrapped in FakeTensors during DDPOptimizer compilation
+        return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

 class MyLinear(torch.nn.Module):
     def __init__(self):
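For context, a minimal, self-contained sketch of the pattern the new test exercises (illustrative only, not part of the PR; it assumes a PyTorch version where torch.where accepts Python scalars, as the test itself does):

```python
import torch

# Standalone version of the test's forward pattern.
x = torch.randn(8, 512)
weight = torch.randn(512, 512)

tmp = torch.mm(x, weight.t())
# The scalar arguments 0.3 and 0.6 are materialized as 0-dim tensors when the
# scalar overload of where is decomposed to aten.where.self(tensor, tensor, tensor);
# during DDPOptimizer compilation those tensors must be wrapped as FakeTensors.
out = tmp + torch.where(tmp < 0.5, 0.3, 0.6)
print(out.shape)  # torch.Size([8, 512])
```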
21 changes: 16 additions & 5 deletions torch/_dynamo/optimizations/distributed.py
@@ -296,8 +296,6 @@ def run_node(self, n: Node) -> Any:
                 assert isinstance(args, tuple)
                 assert isinstance(kwargs, dict)

-                # modify the currently running FX graph
-                # maybe this isn't sound in general, but only changing the target of a node might be ok?
                 if n.op == "call_module":
                     real_mod = self.fetch_attr(n.target)
                     if fake_mode:
@@ -308,15 +306,28 @@ def run_node(self, n: Node) -> Any:
                     log.debug(
                         f"\n---{n.target} graph---\n" + str(curr_submod.graph)
                     )
+
+                    # When calling the compiler on the submod, inputs (new_args) are expected to
+                    # be FakeTensors already since Dynamo would have made them FakeTensors in the
+                    # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
+                    # since this wrapping happens during compilation
                     compiled_submod_real = self.compile_submod(
                         real_mod, new_args, kwargs
                     )
+
+                    # We update the original (outer) graph with a call into the compiled module
+                    # instead of the uncompiled one.
                     self.module.delete_submodule(n.target)
                     n.target = "compiled_" + n.target
                     self.module.add_submodule(n.target, compiled_submod_real)
-                    return curr_submod(*new_args, **kwargs)
-                # then we execute the modified node using the usual logic
-                return getattr(self, n.op)(n.target, new_args, kwargs)
+
+                    # Finally, we have to produce inputs for use compiling the next submodule,
+                    # and these need to be FakeTensors, so we execute the module under fake_mode
+                    with fake_mode:
+                        return curr_submod(*new_args, **kwargs)
Collaborator
This code is executed at runtime, right?

Contributor Author
Not sure what you mean.

Since TorchDynamo defers compilation until the first execution, in a way yes, this code happens "at runtime".

But this code only runs as part of the compilation flow, which in a simple (static model) scenario happens only once. The second time a user calls their compiled DDP model, none of this code should run, since we're not recompiling.

Contributor Author
Maybe you confused it with `WrapperModule.forward` - that's the only piece of code in the whole `ddp_optimizer` file that I'd expect to run repeatedly at runtime. (All it does is unwrap the tuple output from the compiled subgraph.)

Collaborator
Yeah, `WrapperModule.forward` is the place I was thinking of.

This looks fine to me.

When Ed and I were working on it, it was very confusing which part of this was compile time and which was runtime.
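For readers following the compile-time vs. runtime distinction above, here is a rough sketch of the kind of thin runtime wrapper being discussed. The constructor arguments and attribute names are assumptions for illustration, not the exact code in the `ddp_optimizer` file; the one detail taken from the discussion is that `forward` only calls the compiled subgraph and unwraps a singleton tuple output.

```python
import torch

class WrapperModule(torch.nn.Module):
    """Hypothetical sketch: the only per-call (runtime) work in the DDP optimizer."""

    def __init__(self, compiled_submod, unwrap_singleton_tuple):
        super().__init__()
        self.compiled_submod = compiled_submod
        self.unwrap_singleton_tuple = unwrap_singleton_tuple

    def forward(self, *args):
        out = self.compiled_submod(*args)
        # Compiled subgraphs return tuples; if the original submodule returned a
        # single tensor, unwrap it so callers see the same type as before.
        if self.unwrap_singleton_tuple and isinstance(out, (tuple, list)):
            return out[0]
        return out
```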

+                else:
+                    # placeholder or output nodes don't need to get compiled, just executed
+                    return getattr(self, n.op)(n.target, new_args, kwargs)

         submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
         submod_compiler.run(*example_inputs)
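To make the `with fake_mode:` step concrete, here is a minimal sketch (not the PR's code; it assumes a recent PyTorch that exposes `torch._subclasses.fake_tensor.FakeTensorMode`): executing a module while a FakeTensorMode is active produces FakeTensor outputs that carry only shapes and dtypes, which is what is needed as example inputs for compiling the next submodule.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

# Creating the module and input under the mode makes the parameters and inputs
# FakeTensors, so the forward pass does no real compute. (The PR instead fakeifies
# a copy of the real submodule and reuses the already-fake example inputs.)
with fake_mode:
    submod = torch.nn.Linear(512, 512)
    x = torch.randn(8, 512)
    out = submod(x)

print(type(out))  # <class 'torch._subclasses.fake_tensor.FakeTensor'>
print(out.shape)  # torch.Size([8, 512]); usable as a fake input for the next submodule
```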