
Commit 25d1496

mrshenli authored and facebook-github-bot committed
Fix Process Group for tensors shared across processes (#21449)
Summary: Ops on a Process Group (pg) instance will hit an error when input/output tensors are created in a different process, because pg calls `recordStream` on `CUDACachingAllocator`, which only knows about tensors created within the same process. The proposed solution is to add a `suppressError` arg (suggestions for better names?) to `recordStream`. See comments in code for arguments. CC pichuang1984

Pull Request resolved: #21449
Differential Revision: D15689736
Pulled By: mrshenli
fbshipit-source-id: e7fc81b167868f8666536067eaa7ae2c8584d88e
1 parent 50ee1f3 commit 25d1496
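For context, here is a minimal sketch of the pattern this commit unblocks, adapted from the new test_c10d_spawn.py below (the two-GPU requirement and the Gloo options mirror that test; the surrounding harness is illustrative, not part of the commit): the parent process allocates CUDA tensors, spawn-started workers receive them over CUDA IPC, and each worker's collective triggers `recordStream` on memory its local allocator never allocated.

# Illustrative sketch (not part of the commit), adapted from test_c10d_spawn.py.
# Assumes at least 2 GPUs and a c10d build with the Gloo backend.
import tempfile

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp


def worker(rank, filename, shared_tensors, world_size):
    store = c10d.FileStore(filename, world_size)
    opts = c10d.ProcessGroupGloo.Options()
    opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    opts.timeout = 5.0
    pg = c10d.ProcessGroupGloo(store, rank, world_size, opts)
    # shared_tensors[rank] was allocated by the parent process, so this
    # worker's CUDACachingAllocator has no block for its device pointer.
    # Before this fix, the recordStream call inside the collective raised
    # "invalid device pointer"; after it, the unknown pointer is ignored.
    pg.broadcast([shared_tensors[rank]]).wait()


if __name__ == '__main__':
    world_size = 2
    file = tempfile.NamedTemporaryFile(delete=False)
    tensors = [torch.ones(2, 2).to(i) * i for i in range(world_size)]
    mp.spawn(worker, args=(file.name, tensors, world_size), nprocs=world_size)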

File tree

4 files changed: +203 -8 lines changed


.jenkins/pytorch/multigpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -29,4 +29,5 @@ fi
 
 time python test/run_test.py --verbose -i distributed
 time python test/run_test.py --verbose -i c10d
+time python test/run_test.py --verbose -i c10d_spawn
 assert_git_not_dirty

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 9 additions & 8 deletions
@@ -383,15 +383,16 @@ struct THCCachingAllocator
     if (ptr) {
       std::lock_guard<std::recursive_mutex> lock(mutex);
       Block* block = find_allocated_block(ptr);
-      if (!block) {
-        AT_ERROR("invalid device pointer: ", ptr);
-      }
-      if (stream.stream() == block->stream) {
-        // ignore uses on the allocation stream, since those don't require any
-        // special synchronization
-        return;
+      // block could be nullptr in some cases, e.g., tensor loaded from blob, or
+      // shared from another process, or not pointing to a CUDA tensor.
+      if (block) {
+        if (stream.stream() == block->stream) {
+          // ignore uses on the allocation stream, since those don't require any
+          // special synchronization
+          return;
+        }
+        block->stream_uses.insert(stream);
       }
-      block->stream_uses.insert(stream);
     }
   }
 
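In Python terms, the relaxed `recordStream` means that recording a stream on a tensor whose memory the local caching allocator never allocated (for example, one received from another process over CUDA IPC) is now silently skipped rather than raising "invalid device pointer". A hedged sketch of that isolated call path, assuming `torch.Tensor.record_stream` forwards to `CUDACachingAllocator::recordStream` (the spawn harness here is illustrative, not from the commit):

# Illustrative sketch (not part of the commit); assumes at least 1 GPU.
import torch
import torch.multiprocessing as mp


def child(rank, shared):
    # `shared` was allocated by the parent and rebuilt here via CUDA IPC, so
    # the child's CUDACachingAllocator has no block for its pointer.
    # Pre-fix: AT_ERROR("invalid device pointer: ...").  Post-fix: no-op.
    shared.record_stream(torch.cuda.current_stream(shared.device))


if __name__ == '__main__':
    t = torch.ones(2, 2, device='cuda:0')
    mp.spawn(child, args=(t,), nprocs=1)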

test/run_test.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     'autograd',
     'cpp_extensions',
     'c10d',
+    'c10d_spawn',
     'cuda',
     'cuda_primary_ctx',
     'dataloader',

test/test_c10d_spawn.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
+import sys
+import tempfile
+import unittest
+
+import torch
+import torch.distributed as c10d
+import torch.multiprocessing as mp
+
+from common_cuda import TEST_MULTIGPU
+from common_utils import TestCase, load_tests, run_tests
+from common_utils import NO_MULTIPROCESSING_SPAWN
+
+# load_tests from common_utils is used to automatically filter tests for
+# sharding on sandcastle. This line silences flake warnings
+load_tests = load_tests
+
+if not c10d.is_available():
+    print('c10d not available, skipping tests')
+    sys.exit(0)
+
+
+if NO_MULTIPROCESSING_SPAWN:
+    print('spawn not available, skipping tests')
+    sys.exit(0)
+
+
+NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")
+
+
+class ProcessGroupShareTensorTest(TestCase):
+
+    world_size = 2
+
+    @classmethod
+    def opts(cls, threads=2):
+        opts = c10d.ProcessGroupGloo.Options()
+        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        opts.timeout = 5.0
+        opts.threads = threads
+        return opts
+
+    @classmethod
+    def _init_pg_gloo(cls, rank, filename, world_size):
+        store = c10d.FileStore(filename, world_size)
+        return c10d.ProcessGroupGloo(
+            store, rank, world_size, ProcessGroupShareTensorTest.opts())
+
+    @classmethod
+    def _init_pg_nccl(cls, rank, filename, world_size):
+        store = c10d.FileStore(filename, world_size)
+        return c10d.ProcessGroupNCCL(store, rank, world_size)
+
+    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
+        ws = self.world_size
+        # file store will delete the test file on destruction
+        file = tempfile.NamedTemporaryFile(delete=False)
+        ctx = mp.get_context('spawn')
+        c2p = ctx.Queue(2)
+        p2c = ctx.Queue(2)
+        ps = []
+        for i in range(ws):
+            p = ctx.Process(
+                target=f,
+                args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c))
+
+            p.start()
+            ps.append(p)
+
+        for _ in range(ws * n_output):
+            pid, expected, result = c2p.get()
+            self.assertEqual(
+                expected,
+                result,
+                (
+                    "Expect rank {} to broadcast result {} but got {}."
+                ).format(pid, expected, result)
+            )
+
+        for _ in range(ws):
+            p2c.put(0)
+
+        for p in ps:
+            p.join(2)
+
+    # Why classmethod? multiprocessing cannot pickle TestCase subclass when in
+    # spawn mode. See https://bugs.python.org/issue33884.
+    @classmethod
+    def _test_broadcast_process(
+            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, filename, world_size)
+        xs = [shared_tensors[rank]]
+        pg.broadcast(xs).wait()
+        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
+        p2c.get()
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    def test_shared_broadcast_gloo(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_broadcast_process,
+            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_gloo,
+            1)
+
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    @unittest.skipIf(NO_NCCL, "NCCL needed")
+    def test_shared_broadcast_nccl(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_broadcast_process,
+            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_nccl,
+            1)
+
+    @classmethod
+    def _test_allreduce_process(
+            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, filename, world_size)
+        xs = [shared_tensors[rank]]
+        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
+        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
+        p2c.get()
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    def test_shared_allreduce_gloo(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_allreduce_process,
+            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_gloo,
+            1)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    @unittest.skipIf(NO_NCCL, "NCCL needed")
+    def test_shared_allreduce_nccl(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_allreduce_process,
+            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_nccl,
+            1)
+
+    @classmethod
+    def _test_reduce_process(
+            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, filename, world_size)
+        x = shared_tensors[rank]
+        pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
+        if rank == 0:
+            c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
+        else:
+            c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
+        p2c.get()
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    @unittest.skipIf(NO_NCCL, "NCCL needed")
+    def test_shared_reduce_nccl(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_reduce_process,
+            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_nccl,
+            1)
+
+    @classmethod
+    def _test_allgather_process(
+            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
+        pg = init_pg(rank, filename, world_size)
+        xs = [shared_tensors[rank]]
+        ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
+        pg.allgather(ys, xs).wait()
+        for i in range(world_size):
+            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))
+
+        p2c.get()
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    def test_shared_allgather_gloo(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_allgather_process,
+            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_gloo,
+            self.world_size)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
+    @unittest.skipIf(NO_NCCL, "NCCL needed")
+    def test_shared_allgather_nccl(self):
+        self._test_multiprocess(
+            ProcessGroupShareTensorTest._test_allgather_process,
+            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+            ProcessGroupShareTensorTest._init_pg_nccl,
+            self.world_size)
+
+
+if __name__ == '__main__':
+    run_tests()
