Description
🐛 Describe the bug
_mesh_resources.get_root_mesh returns the wrong mesh at the end of the following example (run with torchrun on 8 processes) on torch 2.7.0. The final call is expected to return mesh1, but it returns mesh2 instead.
Is this a feature or a bug?
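One possible explanation, offered as an assumption rather than a confirmed diagnosis: on rank 0, the "c" slice of both meshes in the reproduction script below covers ranks [0, 1] with the same dimension name, so the two submeshes may compare and hash as equal. If the child-to-root registry is a dictionary keyed by the submesh itself, registering the second submesh would overwrite the entry for the first. The toy model below illustrates that mechanism in plain Python; ToyMesh and child_to_root are hypothetical names and are not PyTorch's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a mesh that hashes and compares by value."""
    ranks: tuple
    shape: tuple
    dim_names: tuple


# Two root meshes over the same 8 ranks, with different shapes.
root1 = ToyMesh(tuple(range(8)), (1, 4, 2), ("a", "b", "c"))
root2 = ToyMesh(tuple(range(8)), (2, 2, 2), ("a", "b", "c"))

# On rank 0, slicing dimension "c" yields ranks (0, 1) for both roots.
child1 = ToyMesh((0, 1), (2,), ("c",))
child2 = ToyMesh((0, 1), (2,), ("c",))

# Assumed registry layout: submesh -> root, keyed by value.
child_to_root = {}
child_to_root[child1] = root1
child_to_root[child2] = root2  # same key as child1, so the value is replaced

print(child_to_root[child1].shape)  # (2, 2, 2)
print(len(child_to_root))           # 1

# Keying by object identity instead keeps both entries apart.
by_identity = {id(child1): root1, id(child2): root2}
print(by_identity[id(child1)].shape)  # (1, 4, 2)
```

If this is the cause, the result would depend on the rank: ranks whose "c" slices differ between the two meshes would not see the collision.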
"""
torchrun --nproc_per_node=8 test.py
"""
import torch
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.device_mesh import init_device_mesh
torch.distributed.init_process_group("nccl")
# global mesh1
mesh1 = init_device_mesh(
    "cuda", (1, 4, 2), mesh_dim_names=("a", "b", "c")
)
mesh1_c = mesh1["c"]
# got (1, 4, 2), OK!!
if torch.distributed.get_rank() == 0:
    print(_mesh_resources.get_root_mesh(mesh1_c).shape)
# global mesh2
mesh2 = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("a", "b", "c")
)
mesh2_c = mesh2["c"]
# got (2, 2, 2), it's mesh2 instead of mesh1 !!!
if torch.distributed.get_rank() == 0:
    print(_mesh_resources.get_root_mesh(mesh1_c).shape)
# this assertion fails: the root of mesh1_c is reported as mesh2
assert _mesh_resources.get_root_mesh(mesh1_c) is mesh1
torch.distributed.destroy_process_group()
Versions
Collecting environment information...
PyTorch version: 2.7.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 11.3.0-12) 11.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36
Python version: 3.11.2 (main, Nov 30 2024, 21:22:50) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.10.135.bsk.6-amd64-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H20
GPU 1: NVIDIA H20
GPU 2: NVIDIA H20
GPU 3: NVIDIA H20
GPU 4: NVIDIA H20
GPU 5: NVIDIA H20
GPU 6: NVIDIA H20
GPU 7: NVIDIA H20
Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 180
On-line CPU(s) list: 0-179
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8457C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 45
Socket(s): 2
Stepping: 8
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 4.2 MiB (90 instances)
L1i cache: 2.8 MiB (90 instances)
L2 cache: 180 MiB (90 instances)
L3 cache: 195 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-89
NUMA node1 CPU(s): 90-179
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] onnx==1.16.1
[pip3] onnx-ir==0.1.9
[pip3] onnxscript==0.3.1
[pip3] optree==0.17.0
[pip3] tbb==2022.2.0
[pip3] tcmlib==1.4.0
[pip3] torch==2.7.0+cu128
[pip3] torchaudio==2.7.0+cu128
[pip3] torchlibrosa==0.1.0
[pip3] torchvision==0.22.0+cu128
[pip3] transformer_engine_torch==2.6.0.post1
[pip3] triton==3.3.0
[conda] Could not collect
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci