[Dynamo] Graph Break Introduced by DDP Failed to Explain #94579

@alpha0422

🐛 Describe the bug

Calling torch._dynamo.explain() on a DDP-wrapped model to get the graph break reasons raises an exception. It seems that the graph breaks introduced by DDP do not set the reason attribute (compile_subgraph_reason in the traceback below).
Here's a minimal example to reproduce:

#!/usr/bin/env python

# Usage: torchrun --nnodes=1 --nproc_per_node=1 bug.py

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch._dynamo import explain
torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self, channels=320):
        super().__init__()
        self.channels = channels
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(channels, 32*channels),
            torch.nn.Linear(32*channels, 32*channels),
            torch.nn.Linear(32*channels, channels),
        )

    def forward(self, x):
        return self.layers(x)

torch.distributed.init_process_group("nccl", rank=0, world_size=1)

x = torch.randn((320,), device="cuda")
model = Model().cuda()
ddp_model = DDP(model, bucket_cap_mb=1)

ddp_model(x)

explanation, out_guards, graphs, ops_per_graph, break_reasons, \
    explanation_verbose = explain(ddp_model, x)
print(explanation)

The error message is:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 584, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 315, in compile_fn
    submod_compiler.run(*example_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 130, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 304, in run_node
    compiled_submod_real = self.compile_submod(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 250, in compile_submod
    self.compiler(input_mod, args),
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/debug_utils.py", line 915, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 471, in dynamo_graph_accumulating_compiler
    if gm.compile_subgraph_reason is not None:
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1587, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GraphModule' object has no attribute 'compile_subgraph_reason'

While executing %submod_0 : [#users=1] = call_module[target=submod_0](args = (%x,), kwargs = {})
Original traceback:
None

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "bug-7.py", line 32, in <module>
    explanation_verbose = explain(ddp_model, x)
  File "/usr/lib/python3.8/unittest/mock.py", line 1325, in patched
    return func(*newargs, **newkeywargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 486, in explain
    opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 80, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1098, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1054, in _run_ddp_forward
    return module_to_run(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 326, in catch_errors
    return hijacked_callback(frame, cache_size)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 468, in _convert_frame
    result = inner_convert(frame, cache_size)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 102, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 395, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 382, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1620, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 484, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 453, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1682, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 439, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 510, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 589, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised AttributeError: 'GraphModule' object has no attribute 'compile_subgraph_reason'

While executing %submod_0 : [#users=1] = call_module[target=submod_0](args = (%x,), kwargs = {})
Original traceback:
None

Set torch._dynamo.config.verbose=True for more information
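
The failing line in the traceback reads gm.compile_subgraph_reason directly, which raises when the attribute was never set on that GraphModule. The following is a minimal, PyTorch-free sketch of that failure mode and of a defensive read with getattr. FakeGraphModule and read_break_reason are hypothetical names used only for illustration; this is not the actual PyTorch code or its fix, and it assumes (as suspected above) that the submodules produced for DDP simply lack the attribute.

```python
class FakeGraphModule:
    """Hypothetical stand-in for a GraphModule (not the real torch.fx class)."""

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails; mirrors the
        # error message raised by torch.nn.Module.__getattr__ in the traceback.
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )


def read_break_reason(gm):
    """Return the recorded graph break reason, or None if none was set."""
    return getattr(gm, "compile_subgraph_reason", None)


# A module that had a reason attached (assumed to happen during tracing).
tagged = FakeGraphModule()
tagged.compile_subgraph_reason = "example reason"

# A module created without the attribute (assumed: a DDP-split submodule).
untagged = FakeGraphModule()

try:
    untagged.compile_subgraph_reason  # direct access, as in the failing line
    direct_access_failed = False
except AttributeError:
    direct_access_failed = True

print(direct_access_failed)
print(read_break_reason(untagged))
print(read_break_reason(tagged))
```

With a guarded read like this, a module without a recorded reason would be treated as having no break reason instead of aborting the whole explain() call. Whether that is the right behavior for DDP-introduced breaks is a design question for the maintainers.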

Versions

Collecting environment information...
PyTorch version: 1.14.0a0+44dac51
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 14 2022, 12:59:47)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-99-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 515.65.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          128
On-line CPU(s) list:             0-127
Thread(s) per core:              1
Core(s) per socket:              64
Socket(s):                       2
NUMA node(s):                    8
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7742 64-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         3087.518
CPU max MHz:                     2250.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        4491.73
Virtualization:                  AMD-V

L1d cache:                       4 MiB
L1i cache:                       4 MiB
L2 cache:                        64 MiB
L3 cache:                        512 MiB
NUMA node0 CPU(s):               0-15
NUMA node1 CPU(s):               16-31
NUMA node2 CPU(s):               32-47
NUMA node3 CPU(s):               48-63
NUMA node4 CPU(s):               64-79
NUMA node5 CPU(s):               80-95
NUMA node6 CPU(s):               96-111
NUMA node7 CPU(s):               112-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.22.2
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.14.0a0+44dac51
[pip3] torch-tensorrt==1.4.0.dev0
[pip3] torchtext==0.13.0a0+fae8e8c
[pip3] torchvision==0.15.0a0
[conda] Could not collect

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire
