Commit 8acaa28

mrshenli authored and facebook-github-bot committed
Make CUDACachingAllocator::recordStream() a no-op on null ptrs (#20658)
Summary:
Fixes #20651

Communication collectives in `torch.distributed` call `CUDACachingAllocator::recordStream()` on input and output tensors to prevent their memory blocks from being freed too early. `CUDACachingAllocator` tracks memory blocks by a tensor's data pointer and does not accept null pointers. However, an empty tensor's `storage().data()` might be null. In that case there is no memory block associated with the tensor, so it is fine to make `recordStream()` a no-op.

Tests only cover `broadcast` with empty tensors for the GLOO backend, because GLOO does not support empty inputs (pytorch/gloo/issues/179). That gap can be addressed in either `ProcessGroupGloo` or GLOO itself; more tests will be added once it is filled.

Pull Request resolved: #20658

Differential Revision: D15399371

Pulled By: mrshenli

fbshipit-source-id: d29ebd1c72fddae49531f32695f81b89e42e5a4d
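For context, a minimal end-to-end sketch (not part of the commit) of the call pattern the fix protects: a torch.distributed collective on an empty tensor. It assumes a single-process gloo group set up through the usual MASTER_ADDR/MASTER_PORT environment variables so it can run standalone; the CUDA-side recordStream() path itself is exercised by the NCCL tests added in test/test_c10d.py below.

# Illustrative sketch only, mirroring the new GLOO broadcast test.
# Assumes a single-process "gloo" group; rendezvous settings are arbitrary.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.empty(0)        # empty tensor: numel() == 0, storage().data() may be null
dist.broadcast(x, src=0)  # with CUDA tensors and the NCCL backend, this call path
                          # reaches CUDACachingAllocator::recordStream() on the inputs
assert x.numel() == 0     # broadcasting an empty tensor leaves it empty

dist.destroy_process_group()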
1 parent 0719714 commit 8acaa28

File tree

c10/cuda/CUDACachingAllocator.cpp
test/test_c10d.py

2 files changed, +46 -10 lines changed

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 14 additions & 10 deletions
@@ -378,17 +378,21 @@ struct THCCachingAllocator
 
   void recordStream(void* ptr, cuda::CUDAStream stream)
   {
-    std::lock_guard<std::recursive_mutex> lock(mutex);
-    Block* block = find_allocated_block(ptr);
-    if (!block) {
-      AT_ERROR("invalid device pointer: ", ptr);
-    }
-    if (stream.stream() == block->stream) {
-      // ignore uses on the allocation stream, since those don't require any
-      // special synchronization
-      return;
+    // Empty tensor's storage().data() might be a null ptr. As there is no
+    // blocks associated with those tensors, it is fine to do nothing here.
+    if (ptr) {
+      std::lock_guard<std::recursive_mutex> lock(mutex);
+      Block* block = find_allocated_block(ptr);
+      if (!block) {
+        AT_ERROR("invalid device pointer: ", ptr);
+      }
+      if (stream.stream() == block->stream) {
+        // ignore uses on the allocation stream, since those don't require any
+        // special synchronization
+        return;
+      }
+      block->stream_uses.insert(stream);
     }
-    block->stream_uses.insert(stream);
   }
 
   /** moves a block into a pool of cached free blocks */

test/test_c10d.py

Lines changed: 32 additions & 0 deletions
@@ -579,6 +579,14 @@ def opts(self, threads=2):
         opts.threads = threads
         return opts
 
+    def test_empty_tensors(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
+
+        xs = [torch.FloatTensor([])]
+        pg.broadcast(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
     def test_broadcast_checks(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
@@ -1344,6 +1352,30 @@ def setUp(self):
     def tearDown(self):
         pass
 
+    def test_empty_tensors(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        xs = [torch.cuda.FloatTensor([])]
+        pg.broadcast(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        pg.allreduce(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        pg.reduce(xs).wait()
+        self.assertEqual(0, xs[0].numel())
+
+        ys = [[torch.cuda.FloatTensor([]) for _ in range(self.world_size)]]
+        pg.allgather(ys, xs).wait()
+        for y in ys[0]:
+            self.assertEqual(0, y.numel())
+
+        ys = [torch.cuda.FloatTensor([])]
+        xs = [[torch.cuda.FloatTensor([]) for _ in range(self.world_size)]]
+        pg.reduce_scatter(ys, xs).wait()
+        self.assertEqual(0, ys[0].numel())
+
     def test_broadcast_ops(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
