[c10d] MPI Process Group Implementation #7783
| #include "ProcessGroupMPI.hpp" | ||
|
|
||
| #include <mpi-ext.h> // Needed for CUDA-aware check | ||
| #include <map> | ||
|
|
||
| namespace c10d { | ||
|
|
||
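// Wraps an MPI call and throws std::runtime_error with the file, line,
// and MPI error code on any status other than MPI_SUCCESS. The
// do { ... } while (0) makes the macro behave like a single statement.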
#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      throw std::runtime_error(err);                                     \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

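// Note: a ReduceOp or dtype missing from the maps above (e.g. at::kHalf)
// will surface as std::out_of_range from the .at() lookups in the
// collectives below.

// MPIX_Query_cuda_support() comes from the <mpi-ext.h> extension header
// (Open MPI provides it, for example); builds where
// MPIX_CUDA_AWARE_SUPPORT is not defined fall back to the compile-time
// `return false` branch.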
// Checking CUDA-aware MPI support
bool cudaAwareMpiCheck() {
  // Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    throw std::runtime_error(
        "MPI process group only supports a single tensor op");
  }
  if (!tensors[0].is_contiguous()) {
    throw std::runtime_error("input tensor has to be contiguous");
  }
  if (tensors[0].is_cuda() && !cudaAwareMpiCheck()) {
    throw std::runtime_error(
        "CUDA tensor detected, but the MPI implementation in use "
        "does not have CUDA-aware support");
  }
}

void mpiExit() {
  MPI_CHECK(MPI_Finalize());
}

} // namespace

// ProcessGroupMPI::WorkMPI
ProcessGroupMPI::WorkMPI::WorkMPI() : completed_(false) {}

ProcessGroupMPI::WorkMPI::~WorkMPI() {}

bool ProcessGroupMPI::WorkMPI::isCompleted() const {
  return completed_;
}

bool ProcessGroupMPI::WorkMPI::isSuccess() const {
  return !workException_;
}

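// Blocks the caller until the worker thread marks this work object
// finished; reports failure via the return value (isSuccess()) rather
// than by throwing.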
bool ProcessGroupMPI::WorkMPI::wait() {
  std::unique_lock<std::mutex> lock(workMutex_);
  while (!completed_) {
    workCV_.wait(lock);
  }
  return isSuccess();
}

void ProcessGroupMPI::WorkMPI::finish() {
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    completed_ = true;
  }
  workCV_.notify_all();
}

void ProcessGroupMPI::WorkMPI::finishWithException(
    std::exception_ptr caughtWorkException) {
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    completed_ = true;
    workException_ = caughtWorkException;
  }
  workCV_.notify_all();
}

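// Recovers a std::exception reference by rethrowing the stored
// exception_ptr; only meaningful after a failed wait(), since
// rethrowing a null exception_ptr is undefined behavior.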
const std::exception& ProcessGroupMPI::WorkMPI::exception() const {
  try {
    std::rethrow_exception(workException_);
  } catch (const std::exception& e) {
    return e;
  }
}

// Static global states
int ProcessGroupMPI::numProcessGroups_ = 0;
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
std::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  std::call_once(onceFlagInitMPI, []() {
    MPI_CHECK(MPI_Init_thread(
        nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpiThreadSupport_));
    if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
      throw std::runtime_error(
          "The MPI implementation in use doesn't provide the minimum "
          "level of threading support (MPI_THREAD_SERIALIZED) required "
          "by the c10d package");
    }
    if (std::atexit(mpiExit)) {
      throw std::runtime_error("Failed to register the MPI exit handler");
    }
  });
}

std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI() {
  // One-time MPI initialization
  initMPIOnce();

  int rank = -1;
  int size = -1;
  // Update the world size and rank
  MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
  MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));

  if (rank < 0 || size < 0) {
    throw std::runtime_error("Failed to get the world_size / rank");
  }

  return std::make_shared<ProcessGroupMPI>(rank, size);
}

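// Each ProcessGroupMPI instance owns one worker thread, so without
// MPI_THREAD_MULTIPLE only a single instance may issue MPI calls.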
ProcessGroupMPI::ProcessGroupMPI(int rank, int size)
    : ProcessGroup(rank, size), stop_(false) {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);

  if (mpiThreadSupport_ != MPI_THREAD_MULTIPLE && numProcessGroups_ >= 1) {
    throw std::runtime_error(
        "More than one process group created. This is not supported "
        "because the MPI implementation in use does not provide full "
        "multi-threading support (MPI_THREAD_MULTIPLE)");
  }
  // increase the total PG count
  ++numProcessGroups_;
  globalLock.unlock();

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!queue_.empty()) {
    queueConsumeCV_.wait(lock);
  }
  // Queue is empty, signal stop
  stop_ = true;

  // Wake the worker so it can observe stop_ and terminate
  queueProduceCV_.notify_all();

  lock.unlock();

  // Join the single worker thread
  workerThread_.join();

  // Decrease the number of PGs created
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  --numProcessGroups_;
}

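// MPI_Abort terminates all ranks in MPI_COMM_WORLD, so it is not
// expected to return (hence no MPI_CHECK around it).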
void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}

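// Consumer side of the queue: pops a WorkEntry under pgMutex_, releases
// the lock while the (potentially blocking) MPI call runs, and then
// marks the associated WorkMPI object finished.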
void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();
    queueConsumeCV_.notify_one();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();

    try {
      workEntry->run(workEntry);
      work->finish();
    } catch (...) {
      work->finishWithException(std::current_exception());
    }

    lock.lock();
  }
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry) {
  auto work = std::make_shared<WorkMPI>();
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.push_back(std::make_tuple(std::move(entry), work));
  queueProduceCV_.notify_one();
  return work;
}

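// Collective entry points: each one validates the input, captures the
// actual MPI call in a lambda, and hands it to the worker thread via
// enqueue().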
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts](std::unique_ptr<WorkEntry>& entry) {
        auto data = (*entry->src)[0];
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.type().scalarType()),
            opts.rootRank,
            MPI_COMM_WORLD));
      };
  auto entry = std::unique_ptr<WorkEntry>(
      new WorkEntry(&tensors, nullptr, std::move(runFunc)));
  return enqueue(std::move(entry));
}

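// The allreduce is done in place: MPI_IN_PLACE makes the tensor's
// buffer serve as both the send and the receive buffer.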
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts](std::unique_ptr<WorkEntry>& entry) {
        auto data = (*entry->src)[0];
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.type().scalarType()),
            mpiOp.at(opts.reduceOp),
            MPI_COMM_WORLD));
      };
  auto entry = std::unique_ptr<WorkEntry>(
      new WorkEntry(&tensors, nullptr, std::move(runFunc)));
  return enqueue(std::move(entry));
}

} // namespace c10d
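
For context, here is a minimal sketch of how this process group might be driven from a standalone C++ program launched with mpirun. The tensor construction (at::ones) and the default-constructed AllreduceOptions (assumed to default to a SUM reduce op) are illustrative assumptions, not part of this diff:

#include "ProcessGroupMPI.hpp"

#include <ATen/ATen.h>

int main() {
  // MPI_Init_thread and rank/size discovery happen inside the factory.
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();

  // Each rank contributes a tensor of ones; after the in-place SUM
  // allreduce, every element should equal the world size.
  std::vector<at::Tensor> tensors = {at::ones({8})}; // assumed ATen API
  auto work = pg->allreduce(tensors, c10d::AllreduceOptions());

  // wait() blocks until the worker thread has run MPI_Allreduce and
  // reports failure via its return value.
  return work->wait() ? 0 : 1;
}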