
torch.where + DDPOptimizer + Dynamo causes FakeTensor error #92941

Description

@wconstab

🐛 Describe the bug

Summary

When torch.where(tensor, c, d) with scalar constants c and d appears in a submodule graph produced by the DDP graph optimizer, the constants c and d appear to get interpreted as Tensors that are not FakeTensors. It is unclear why they are being interpreted as tensors in the first place.

Something seems to be going wrong in the FakeTensor layer around DDPOptimizer and its interface for calling into the submodule compiler. (cc @ezyang @soumith @msaroufim @ngimel @bdhirsh @voznesenskym)
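
As a minimal sketch (not from the original report) of what this failure mode might look like at the FakeTensor layer, assuming the scalar constants get lifted into real (non-fake) tensor constants inside the split submodule:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

# A plain tensor created outside the mode -- a stand-in for a lifted
# scalar constant (hypothetical; the actual lifting path is unconfirmed).
real_constant = torch.tensor(0.4)

with fake_mode:
    fake = torch.empty(4)   # factory call under the mode -> FakeTensor
    cond = fake <= 0.5      # also a FakeTensor
    other = torch.ones(4)   # FakeTensor
    # Mixing a real tensor into an op on FakeTensors typically raises an
    # error asking that all Tensors be converted to FakeTensors first --
    # the same shape of failure suspected in the DDPOptimizer submodule path.
    torch.where(cond, real_constant, other)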

Repro

import torch
import os
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # Single-process setup; env vars, if already set, take precedence
    # over the arguments.
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", str(rank))
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", str(world_size))
    dist.init_process_group("nccl")

class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        tmp = self.net(x)
        # torch.where with scalar constants for both branches is the trigger.
        return torch.where(tmp <= 0.5, 0.4, 1.0)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # The large first layer appears to be what makes DDPOptimizer split
        # the graph into multiple submodules (submod_0 / submod_1 in the
        # traceback below).
        self.net = nn.Sequential(
            nn.Linear(10, 100000),
            MyModule(100000, 5),
        )

    def forward(self, x):
        return self.net(x)

setup(0, 1)
model = ToyModel()
model.cuda()
inputs = (torch.randn(20, 10, device="cuda"),)

omodel = torch.compile(DDP(model))
omodel(*inputs)  # fails while compiling the DDPOptimizer submodules
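
(The repro assumes a CUDA build with NCCL available; since setup defaults to RANK=0 and WORLD_SIZE=1, it should run directly as a single process, with no launcher such as torchrun needed.)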

Running it gives:

Traceback (most recent call last):
  File "repro.py", line 45, in <module>
    omodel(*inputs)
  File "/scratch/whc/work/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/nn/parallel/distributed.py", line 1158, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/scratch/whc/work/pytorch/torch/nn/parallel/distributed.py", line 1114, in _run_ddp_forward
    return module_to_run(*inputs, **kwargs)
  File "/scratch/whc/work/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/_dynamo/eval_frame.py", line 329, in catch_errors
    return hijacked_callback(frame, cache_size, hooks)
  File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 102, in _fn
    return fn(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/_dynamo/utils.py", line 96, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 260, in _convert_frame_assert
    return _compile(
  File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 321, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/whc/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/whc/work/pytorch/torch/_dynamo/convert_frame.py", line 308, in transform
    tracer.run()
  File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1692, in run
    super().run()
  File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/whc/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 527, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 598, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/scratch/whc/work/pytorch/torch/_dynamo/output_graph.py", line 674, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised BdbQuit: While executing %submod_1 : [#users=1] = call_module[target=submod_1](args = (%submod_0,), kwargs = {})
Original traceback:
None

Versions

Repro'd locally only on a 1/20 snapshot of master (9db4323), but the issue has been reported more recently on internal fbcode builds.

cc @ezyang @soumith @msaroufim @ngimel @bdhirsh

Labels

oncall: pt2, triaged
