@@ -2,6 +2,8 @@
 
 #include <map>
 
+#include <c10/core/DeviceGuard.h>
+
 #if defined(OPEN_MPI) && OPEN_MPI
 #include <mpi-ext.h> // Needed for CUDA-aware check
 #endif
@@ -316,6 +318,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Bcast(
             data.data_ptr(),
@@ -337,6 +340,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allreduce(
             MPI_IN_PLACE,
@@ -363,6 +367,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce(
         void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
         void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Reduce(
             sendbuf,
@@ -402,6 +407,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather(
         std::vector<at::Tensor>& outputDataVec = entry->dst;
         auto flatOutputTensor = newLikeFlat(outputDataVec);
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allgather(
             data.data_ptr(),
@@ -456,6 +462,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
           recvbuf = flatOutputTensor.data_ptr();
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Gather(
             data.data_ptr(),
@@ -529,6 +536,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
           }
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Scatter(
             sendbuf,
@@ -569,6 +577,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::send(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Isend(
         tensor.data_ptr(),
@@ -593,6 +602,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
@@ -616,6 +626,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
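For context, a minimal sketch of the pattern these hunks introduce, assuming a contiguous tensor and a CUDA-aware MPI build; the helper name broadcastInPlace and the MPI_BYTE byte-wise transfer are illustrative assumptions, not taken from the PyTorch sources. c10::DeviceGuard is an RAII helper that makes the tensor's device current for the lifetime of the guard and restores the previous device afterwards, so the MPI call runs with the correct device active when it is handed a pointer to GPU memory; for CPU tensors the guard is effectively a no-op.

// Sketch only: the real ProcessGroupMPI maps scalar_type() to an MPI datatype
// and serializes MPI calls behind pgGlobalMutex_; both are omitted here.
#include <c10/core/DeviceGuard.h>
#include <ATen/ATen.h>
#include <mpi.h>

// Hypothetical helper illustrating the guard-before-MPI-call pattern.
void broadcastInPlace(at::Tensor& data, int rootRank, MPI_Comm comm) {
  // Switch to the tensor's device (e.g. cuda:1) before handing its buffer to MPI.
  c10::DeviceGuard guard(data.device());
  MPI_Bcast(
      data.data_ptr(),
      static_cast<int>(data.numel() * data.element_size()),
      MPI_BYTE, // byte-wise send of a contiguous tensor (assumption)
      rootRank,
      comm);
  // Previous device is restored when guard goes out of scope.
}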