🐛 Describe the bug
We found that explaining graph break reasons with torch._dynamo.explain() raises an exception; it seems that graph breaks introduced by DDP don't set the reason attribute.
Here's a minimal example to reproduce:
#!/usr/bin/env python
# Usage: torchrun --nnodes=1 --nproc_per_node=1 bug.py
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch._dynamo import explain

torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self, channels=320):
        super().__init__()
        self.channels = channels
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(channels, 32 * channels),
            torch.nn.Linear(32 * channels, 32 * channels),
            torch.nn.Linear(32 * channels, channels),
        )

    def forward(self, x):
        return self.layers(x)

torch.distributed.init_process_group("nccl", rank=0, world_size=1)
x = torch.randn((320,), device="cuda")
model = Model().cuda()
ddp_model = DDP(model, bucket_cap_mb=1)
ddp_model(x)
explanation, out_guards, graphs, ops_per_graph, break_reasons, \
    explanation_verbose = explain(ddp_model, x)
print(explanation)

The error message is:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 584, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 315, in compile_fn
submod_compiler.run(*example_inputs)
File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 130, in run
self.env[node] = self.run_node(node)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 304, in run_node
compiled_submod_real = self.compile_submod(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/optimizations/distributed.py", line 250, in compile_submod
self.compiler(input_mod, args),
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/debug_utils.py", line 915, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 471, in dynamo_graph_accumulating_compiler
if gm.compile_subgraph_reason is not None:
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1587, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GraphModule' object has no attribute 'compile_subgraph_reason'
While executing %submod_0 : [#users=1] = call_module[target=submod_0](args = (%x,), kwargs = {})
Original traceback:
None
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "bug-7.py", line 32, in <module>
explanation_verbose = explain(ddp_model, x)
File "/usr/lib/python3.8/unittest/mock.py", line 1325, in patched
return func(*newargs, **newkeywargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 486, in explain
opt_f(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1480, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 80, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1098, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1054, in _run_ddp_forward
return module_to_run(*inputs, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1480, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 326, in catch_errors return hijacked_callback(frame, cache_size) File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 468, in _convert_frame
result = inner_convert(frame, cache_size)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 102, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 395, in _compile
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 382, in transform
tracer.run()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1620, in run
super().run()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 484, in run
and self.step()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 453, in step
getattr(self, inst.opname)(inst)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1682, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 439, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 510, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 589, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: compile_fn raised AttributeError: 'GraphModule' object has no attribute 'compile_subgraph_reason'
While executing %submod_0 : [#users=1] = call_module[target=submod_0](args = (%x,), kwargs = {})
Original traceback:
None
Set torch._dynamo.config.verbose=True for more information
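For context, the failing check is in dynamo_graph_accumulating_compiler (eval_frame.py line 471 in the traceback): the DDP backend in optimizations/distributed.py splits the graph into submodules, and those submodules apparently never get compile_subgraph_reason set, so the plain attribute access falls through to nn.Module.__getattr__ and raises. Below is a minimal sketch of the behaviour and of a defaulted lookup that avoids the crash on our side; the GraphModule here is just a stand-in traced from a toy module, not the real DDP-split submodule, and whether the attribute should instead be populated when the graph is split is probably the better fix.

import torch
import torch.fx

# Stand-in GraphModule; the real ones come from the DDP bucket-based graph split.
gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))

# What eval_frame.py does today, and what raises AttributeError, because
# nn.Module.__getattr__ raises for attributes that were never set:
#     if gm.compile_subgraph_reason is not None: ...

# A defaulted lookup tolerates submodules that never had the attribute set:
reason = getattr(gm, "compile_subgraph_reason", None)
if reason is not None:
    print(reason)
else:
    print("no compile_subgraph_reason recorded for this submodule")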
Versions

Collecting environment information...
PyTorch version: 1.14.0a0+44dac51
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31
Python version: 3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-99-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 515.65.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 1
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 3087.518
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4491.73
Virtualization: AMD-V
L1d cache: 4 MiB
L1i cache: 4 MiB
L2 cache: 64 MiB
L3 cache: 512 MiB
NUMA node0 CPU(s): 0-15
NUMA node1 CPU(s): 16-31
NUMA node2 CPU(s): 32-47
NUMA node3 CPU(s): 48-63
NUMA node4 CPU(s): 64-79
NUMA node5 CPU(s): 80-95
NUMA node6 CPU(s): 96-111
NUMA node7 CPU(s): 112-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca
Versions of relevant libraries:
[pip3] numpy==1.22.2
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.14.0a0+44dac51
[pip3] torch-tensorrt==1.4.0.dev0
[pip3] torchtext==0.13.0a0+fae8e8c
[pip3] torchvision==0.15.0a0
[conda] Could not collect
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire