🐛 Describe the bug
Summary
When torch.where(tensor, c, d) with scalar constants c and d appears in a submodule graph produced by the DDP graph optimizer (DDPOptimizer), the constants c and d appear to get interpreted as Tensors that are not FakeTensors. It is unclear why they are being interpreted as tensors in the first place.
Something seems to be going wrong in the FakeTensor layer around DDPOptimizer and its interface for calling into the submodule compiler. (cc @ezyang @soumith @msaroufim @ngimel @bdhirsh @voznesenskym)
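For context, the scalar form of torch.where used in the repro below is equivalent to passing 0-dim tensors for the two branches. A plausible (unconfirmed) reading is that somewhere in the compile stack the constants get materialized as real tensors along these lines, which would explain how non-fake tensors could end up inside the fake-tensor trace:

import torch

x = torch.randn(4)
mask = x <= 0.5

# Scalar overload, as written in MyModule.forward below.
a = torch.where(mask, 0.4, 1.0)

# Equivalent call with explicit 0-dim tensors; if the compile stack wraps the
# scalars like this without fake-ifying them, they would show up in the submod
# graph as plain Tensors rather than FakeTensors (assumption, not verified).
b = torch.where(mask, torch.tensor(0.4), torch.tensor(1.0))

assert torch.equal(a, b)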
Repro
import torch
import os
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", "0")
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
    dist.init_process_group("nccl")


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super(MyModule, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        tmp = self.net(x)
        return torch.where(tmp <= 0.5, 0.4, 1.0)


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net = nn.Sequential(
            *[nn.Linear(10, 100000)]
            + [MyModule(100000, 5)]
        )

    def forward(self, x):
        return self.net(x)


setup(0, 1)
model = ToyModel()
model.cuda()
inputs = (torch.randn(20, 10, device="cuda"),)
omodel = torch.compile(DDP(model))
omodel(*inputs)
gives
Traceback (most recent call last):
File "repro.py", line 45, in <module>
omodel(*inputs)
File "/scratch/whc/work/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/nn/parallel/distributed.py", line 1158, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/scratch/whc/work/pytorch/torch/nn/parallel/distributed.py", line 1114, in _run_ddp_forward
return module_to_run(*inputs, **kwargs)
File "/scratch/whc/work/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 329, in catch_errors
return hijacked_callback(frame, cache_size, hooks)
File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 102, in _fn
return fn(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 260, in _convert_frame_assert
return _compile(
File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 321, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/whc/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 308, in transform
tracer.run()
File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1692, in run
super().run()
File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 538, in run
and self.step()
File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 501, in step
getattr(self, inst.opname)(inst)
File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 527, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 598, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 674, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised BdbQuit: While executing %submod_1 : [#users=1] = call_module[target=submod_1](args = (%submod_0,), kwargs = {})
Original traceback:
None
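As a hypothetical narrowing step (not part of the original report), compiling the same model without the DDP wrapper, or with DDPOptimizer's graph splitting disabled via torch._dynamo.config.optimize_ddp, should succeed if the failure really is specific to the DDPOptimizer / submodule-compiler path:

import torch._dynamo

# 1) No DDP wrapper -> DDPOptimizer is never invoked.
torch._dynamo.reset()
plain = torch.compile(model)
plain(*inputs)

# 2) Keep DDP, but skip the DDPOptimizer graph splitting.
#    (Sketch only; assumes a fresh process, since re-wrapping the same module
#    in DDP twice is not recommended.)
torch._dynamo.reset()
torch._dynamo.config.optimize_ddp = False
omodel_noopt = torch.compile(DDP(model))
omodel_noopt(*inputs)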
Versions
Reproduced locally only on a 1/20 snapshot of master (9db4323), but the issue has also been reported more recently on internal fbcode builds.