
Commit d23f4c1

Fix Process Group for tensors shared across processes
1 parent 8215f44

File tree: 5 files changed, +79 -19 lines


c10/cuda/CUDACachingAllocator.cpp
Lines changed: 21 additions & 10 deletions

@@ -376,22 +376,33 @@ struct THCCachingAllocator
     cacheInfoAux(small_blocks, dev_id, total, largest);
   }
 
-  void recordStream(void* ptr, cuda::CUDAStream stream)
+  void recordStream(void* ptr, cuda::CUDAStream stream, bool suppressError=false)
   {
     // Empty tensor's storage().data() might be a null ptr. As there is no
     // blocks associated with those tensors, it is fine to do nothing here.
     if (ptr) {
       std::lock_guard<std::recursive_mutex> lock(mutex);
       Block* block = find_allocated_block(ptr);
       if (!block) {
-        AT_ERROR("invalid device pointer: ", ptr);
-      }
-      if (stream.stream() == block->stream) {
-        // ignore uses on the allocation stream, since those don't require any
-        // special synchronization
-        return;
+        // In some cases (e.g., a tensor loaded from a blob, or shared by
+        // another process), this CUDACachingAllocator does not know about the
+        // ptr, and the caller of this function might not have enough context
+        // to check where the tensor originated. One option is to expose a new
+        // API from CUDACachingAllocator to check whether it knows about the
+        // ptr, but that would force other use cases to unnecessarily do two
+        // map lookups (one check + one recordStream). Hence, we provide a
+        // suppressError argument to avoid the error and the double lookup.
+        if (!suppressError) {
+          AT_ERROR("invalid device pointer: ", ptr);
+        }
+      } else {
+        if (stream.stream() == block->stream) {
+          // ignore uses on the allocation stream, since those don't require any
+          // special synchronization
+          return;
+        }
+        block->stream_uses.insert(stream);
       }
-      block->stream_uses.insert(stream);
     }
   }
 
@@ -651,9 +662,9 @@ void* getBaseAllocation(void *ptr, size_t *size)
   return caching_allocator.getBaseAllocation(ptr, size);
 }
 
-void recordStream(void *ptr, cuda::CUDAStream stream)
+void recordStream(void *ptr, cuda::CUDAStream stream, bool suppressError)
 {
-  caching_allocator.recordStream(ptr, stream);
+  caching_allocator.recordStream(ptr, stream, suppressError);
 }
 
 std::mutex* getFreeMutex()
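
For illustration only, a minimal caller-side sketch of the new suppressError argument (the helper name recordMaybeForeign and the owns() check it contrasts with are hypothetical, not part of this commit): passing true keeps the single map lookup while letting recordStream ignore a pointer this allocator never allocated, such as CUDA storage shared in from another process.

// Hypothetical caller-side sketch; not part of the diff above.
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

void recordMaybeForeign(const at::Tensor& t, c10::cuda::CUDAStream stream) {
  // Rejected alternative (two map lookups): first ask the allocator whether
  // it knows the pointer, then record the stream, e.g. with a hypothetical
  // CUDACachingAllocator::owns(t.storage().data()) check.
  //
  // Chosen design (one lookup): ask recordStream to ignore unknown pointers,
  // e.g. storage that was allocated by a different process and mapped here.
  c10::cuda::CUDACachingAllocator::recordStream(
      t.storage().data(), stream, /*suppressError=*/true);
}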

c10/cuda/CUDACachingAllocator.h
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ C10_CUDA_API Allocator* get();
 C10_CUDA_API void emptyCache();
 C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
 C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
-C10_CUDA_API void recordStream(void *ptr, CUDAStream stream);
+C10_CUDA_API void recordStream(void *ptr, CUDAStream stream, bool suppressError=false);
 C10_CUDA_API uint64_t currentMemoryAllocated(int device);
 C10_CUDA_API uint64_t maxMemoryAllocated(int device);
 C10_CUDA_API void resetMaxMemoryAllocated(int device);

test/test_c10d.py
Lines changed: 50 additions & 1 deletion

@@ -20,9 +20,10 @@
 import torch.nn.functional as F
 import torch.distributed as c10d
 import torch.distributed as dist
+import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel
 
-from common_utils import TestCase, load_tests, run_tests
+from common_utils import TestCase, load_tests, run_tests, PY3
 from common_utils import retry_on_address_already_in_use_error
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -1606,6 +1607,54 @@ def allreduce(tensors):
                     tensors_list[i - 2][j])
 
 
+class ProcessGroupShareTensorTest(TestCase):
+
+    @property
+    def world_size(self):
+        return 2
+
+    def opts(threads=2):
+        opts = c10d.ProcessGroupGloo.Options()
+        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        opts.timeout = 5.0
+        opts.threads = threads
+        return opts
+
+    def _test_allreduce_gloo_process(rank, filename, shared_tensors, world_size):
+        store = c10d.FileStore(filename, world_size)
+        pg = c10d.ProcessGroupGloo(
+            store, rank, world_size, ProcessGroupShareTensorTest.opts())
+        xs = [shared_tensors[rank]]
+        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
+        xs[0].to('cpu').allclose(torch.ones(2, 2))
+
+    @unittest.skipIf(not PY3, "Python 3 needed")
+    @skip_if_not_multigpu
+    def test_allreduce_gloo(self):
+        file = tempfile.NamedTemporaryFile(delete=False)
+        shared_tensors = [torch.ones(2, 2).to(i).share_memory_() for i in range(2)]
+        mp.spawn(ProcessGroupShareTensorTest._test_allreduce_gloo_process,
+                 args=(file.name, shared_tensors, self.world_size),
+                 nprocs=self.world_size,
+                 join=True)
+
+    def _test_allreduce_nccl_process(rank, filename, shared_tensors, world_size):
+        store = c10d.FileStore(filename, world_size)
+        pg = c10d.ProcessGroupNCCL(store, rank, world_size)
+        xs = [shared_tensors[rank]]
+        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
+        xs[0].to('cpu').allclose(torch.ones(2, 2))
+
+    @unittest.skipIf(not PY3, "Python 3 needed")
+    @skip_if_not_multigpu
+    def test_allreduce_nccl(self):
+        file = tempfile.NamedTemporaryFile(delete=False)
+        shared_tensors = [torch.ones(2, 2).to(i).share_memory_() for i in range(2)]
+        mp.spawn(ProcessGroupShareTensorTest._test_allreduce_nccl_process,
+                 args=(file.name, shared_tensors, self.world_size),
+                 nprocs=self.world_size,
+                 join=True)
+
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()

torch/lib/c10d/ProcessGroupGloo.cpp
Lines changed: 2 additions & 2 deletions

@@ -161,7 +161,7 @@ void initializeStreamsEvents(
     // `tensors` are created on a different stream. Hence, they must record
     // new streams in this Work to prevent being freed before the Work finishes.
     c10::cuda::CUDACachingAllocator::recordStream(
-        tensors[i].storage().data(), streams[i]);
+        tensors[i].storage().data(), streams[i], true);
   }
 }
 
@@ -205,7 +205,7 @@ void initializeStreamsEvents(
       // new streams in this Work to prevent being freed before the Work
      // finishes.
       c10::cuda::CUDACachingAllocator::recordStream(
-          tensor.storage().data(), streams[i]);
+          tensor.storage().data(), streams[i], true);
     }
   }
 }

torch/lib/c10d/ProcessGroupNCCL.cpp
Lines changed: 5 additions & 5 deletions

@@ -414,7 +414,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     //
     // See [Sync Streams].
     c10::cuda::CUDACachingAllocator::recordStream(
-        inputs[i].storage().data(), ncclStream);
+        inputs[i].storage().data(), ncclStream, true);
 
     C10D_NCCL_CHECK(fn(
         inputs[i],
@@ -529,7 +529,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
     [&] (at::Tensor& input, at::Tensor& output,
         ncclComm_t comm, at::cuda::CUDAStream& stream) {
       c10::cuda::CUDACachingAllocator::recordStream(
-        output.storage().data(), stream
+        output.storage().data(), stream, true
       );
       return ncclAllGather(
           input.data_ptr(),
@@ -548,7 +548,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
       for (size_t j = 0; j < outputTensors[0].size(); ++j) {
         // See [Sync Streams].
         c10::cuda::CUDACachingAllocator::recordStream(
-            outputTensors[i][j].storage().data(), ncclStreams[i]);
+            outputTensors[i][j].storage().data(), ncclStreams[i], true);
 
         outputTensors[i][j].copy_(outputFlattened[i][j], true);
       }
@@ -572,7 +572,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
     [&] (at::Tensor& input, at::Tensor& output,
         ncclComm_t comm, at::cuda::CUDAStream& stream) {
       c10::cuda::CUDACachingAllocator::recordStream(
-        output.storage().data(), stream
+        output.storage().data(), stream, true
       );
       return ncclReduceScatter(
          input.data_ptr(),
@@ -591,7 +591,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
      for (size_t j = 0; j < inputTensors[0].size(); ++j) {
        // See [Sync Streams].
        c10::cuda::CUDACachingAllocator::recordStream(
-           inputTensors[i][j].storage().data(), ncclStreams[i]);
+           inputTensors[i][j].storage().data(), ncclStreams[i], true);
 
        inputFlattened[i][j].copy_(inputTensors[i][j], true);
      }
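
A condensed sketch of the [Sync Streams] pattern followed by the call sites above (the helper name launchOnSideStream and its body are illustrative, not PyTorch API): before enqueueing asynchronous work on a side stream, the process group records that stream with the caching allocator so the tensor's block is not reclaimed or reused while the collective may still be reading or writing it, and it now passes true because a tensor shared from another process may be backed by storage this allocator never allocated.

// Illustrative sketch of the pattern; not part of the diff above.
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

void launchOnSideStream(at::Tensor& input, c10::cuda::CUDAStream sideStream) {
  // Mark input's block as also in use on sideStream, so the caching
  // allocator will not hand its memory to a new allocation until the work
  // queued on sideStream has completed. The third argument (suppressError)
  // keeps this a no-op for storage owned by another process.
  c10::cuda::CUDACachingAllocator::recordStream(
      input.storage().data(), sideStream, /*suppressError=*/true);

  // ... enqueue the actual collective kernel on sideStream here ...
}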
