
Commit c9f41e9

pietern authored and facebook-github-bot committed
Add device guard around MPI operations (#22446)
Summary: If the current CUDA device is not the same as the device that hosts the tensor the operation works on, OpenMPI will segfault, as reported in #21922. This change adds a device guard for every operation to ensure the correct device is set. Fixes #21922.

Pull Request resolved: #22446
Differential Revision: D16106823
Pulled By: pietern
fbshipit-source-id: 99d762eb3851c0a0e0b4fe81cf27c1c8d35596cc
1 parent abb2e68 commit c9f41e9
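
For context, the pattern added by this commit looks roughly like the sketch below: the current CUDA device is switched to the tensor's device before the MPI call and restored afterwards. The wrapper function name, the MPI_FLOAT datatype, and the communicator argument are illustrative assumptions, not part of the actual c10d code.

#include <mpi.h>
#include <c10/core/DeviceGuard.h>
#include <torch/torch.h>

// Hypothetical helper: broadcast a CUDA tensor via CUDA-aware OpenMPI.
// Without the guard, OpenMPI may touch the buffer while a different CUDA
// device is current and segfault (the failure mode reported in #21922).
void broadcastOnTensorDevice(at::Tensor tensor, int root, MPI_Comm comm) {
  // RAII guard: sets the current device to the tensor's device and
  // restores the previous device when it goes out of scope.
  c10::DeviceGuard guard(tensor.device());

  MPI_Bcast(
      tensor.data_ptr(),
      static_cast<int>(tensor.numel()),
      MPI_FLOAT, // assumes a float tensor for brevity
      root,
      comm);
}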

File tree

1 file changed: +11, −0 lines


torch/lib/c10d/ProcessGroupMPI.cpp

Lines changed: 11 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 #include <map>
 
+#include <c10/core/DeviceGuard.h>
+
 #if defined(OPEN_MPI) && OPEN_MPI
 #include <mpi-ext.h> // Needed for CUDA-aware check
 #endif
@@ -316,6 +318,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Bcast(
             data.data_ptr(),
@@ -337,6 +340,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allreduce(
             MPI_IN_PLACE,
@@ -363,6 +367,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce(
         void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
         void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Reduce(
             sendbuf,
@@ -402,6 +407,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather(
         std::vector<at::Tensor>& outputDataVec = entry->dst;
         auto flatOutputTensor = newLikeFlat(outputDataVec);
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allgather(
             data.data_ptr(),
@@ -456,6 +462,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
           recvbuf = flatOutputTensor.data_ptr();
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Gather(
             data.data_ptr(),
@@ -529,6 +536,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
           }
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Scatter(
             sendbuf,
@@ -569,6 +577,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::send(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Isend(
         tensor.data_ptr(),
@@ -593,6 +602,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
@@ -616,6 +626,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
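
Each guard is constructed inside the lambda or block that performs the MPI call, so the device switch is scoped to that call and the previous device is restored on exit. A minimal sketch of that RAII behaviour, assuming a CUDA build of PyTorch; device indices and the program itself are illustrative only:

#include <c10/core/DeviceGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <iostream>

int main() {
  c10::cuda::set_device(0); // make cuda:0 the current device
  {
    // Switch the current device to cuda:1 for the duration of this scope,
    // e.g. around an MPI call on a tensor that lives on cuda:1.
    c10::DeviceGuard guard(c10::Device(c10::kCUDA, 1));
    std::cout << "inside:  "
              << static_cast<int>(c10::cuda::current_device()) << "\n"; // 1
  } // guard destroyed here; the previous device is restored
  std::cout << "outside: "
            << static_cast<int>(c10::cuda::current_device()) << "\n";   // 0
  return 0;
}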

0 commit comments