18 changes: 18 additions & 0 deletions torch/lib/c10d/CMakeLists.txt
@@ -14,6 +14,15 @@ if(NOT Gloo_FOUND)
message(FATAL_ERROR "Gloo not found")
endif()

find_package(MPI)
if(MPI_FOUND)
  MESSAGE(STATUS "MPI_INCLUDE_PATH: ${MPI_INCLUDE_PATH}")
  MESSAGE(STATUS "MPI_LIBRARIES: ${MPI_LIBRARIES}")
  MESSAGE(STATUS "MPIEXEC: ${MPIEXEC}")
else()
  MESSAGE(STATUS "Not able to find MPI, will compile c10d without MPI support")
endif()

set(C10D_SRCS
Utils.cpp
Store.cpp
@@ -38,6 +47,15 @@ add_library(c10d_gloo ${C10D_GLOO_SRCS})
target_include_directories(c10d_gloo PUBLIC ${GLOO_INCLUDE_DIR})
target_link_libraries(c10d_gloo PUBLIC c10d ${GLOO_LIBRARIES})

if(MPI_FOUND)
  set(C10D_MPI_SRCS
    ProcessGroupMPI.cpp
  )
  add_library(c10d_mpi ${C10D_MPI_SRCS})
  target_include_directories(c10d_mpi PUBLIC ${MPI_INCLUDE_PATH})
  target_link_libraries(c10d_mpi PUBLIC c10d ${MPI_LIBRARIES})
endif()

add_subdirectory(example)

enable_testing()
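
For context, a downstream CMake target could consume the new optional library roughly as follows. This is only a sketch: the target name my_mpi_app and its source file are hypothetical, and it assumes MPI was found the same way as above so that c10d_mpi exists.

find_package(MPI REQUIRED)              # same FindMPI module used above
add_executable(my_mpi_app main.cpp)     # hypothetical consumer target
# c10d_mpi declares its MPI include path and libraries as PUBLIC above,
# so linking it also propagates ${MPI_INCLUDE_PATH} and ${MPI_LIBRARIES}.
target_link_libraries(my_mpi_app PRIVATE c10d_mpi)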
293 changes: 293 additions & 0 deletions torch/lib/c10d/ProcessGroupMPI.cpp
@@ -0,0 +1,293 @@
#include "ProcessGroupMPI.hpp"

#include <mpi-ext.h> // Needed for CUDA-aware check
#include <map>

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      throw std::runtime_error(err);                                     \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

// Checking CUDA-aware MPI support
bool cudaAwareMpiCheck() {
  // Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    throw std::runtime_error(
        "MPI process group only supports a single "
        "tensor op");
  }
  if (!tensors[0].is_contiguous()) {
    throw std::runtime_error("input tensor has to be contiguous");
  }
  if (tensors[0].is_cuda() && !cudaAwareMpiCheck()) {
    throw std::runtime_error(
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void mpiExit() {
  MPI_CHECK(MPI_Finalize());
}

} // namespace

// ProcessGroupMPI::WorkMPI
ProcessGroupMPI::WorkMPI::WorkMPI() : completed_(false) {}

ProcessGroupMPI::WorkMPI::~WorkMPI() {}

bool ProcessGroupMPI::WorkMPI::isCompleted() const {
  return completed_;
}

bool ProcessGroupMPI::WorkMPI::isSuccess() const {
  return !workException_;
}

bool ProcessGroupMPI::WorkMPI::wait() {
  std::unique_lock<std::mutex> lock(workMutex_);
  while (!completed_) {
    workCV_.wait(lock);
  }
  return isSuccess();
}

void ProcessGroupMPI::WorkMPI::finish() {
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    completed_ = true;
  }
  workCV_.notify_all();
}

void ProcessGroupMPI::WorkMPI::finishWithException(
    std::exception_ptr caughtWorkException) {
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    completed_ = true;
    workException_ = caughtWorkException;
  }
  workCV_.notify_all();
}

const std::exception& ProcessGroupMPI::WorkMPI::exception() const {
  try {
    std::rethrow_exception(workException_);
  } catch (const std::exception& e) {
    return e;
  }
}

// Static global states
int ProcessGroupMPI::numProcessGroups_ = 0;
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
std::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  std::call_once(onceFlagInitMPI, []() {
    MPI_CHECK(MPI_Init_thread(
        nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpiThreadSupport_));
    if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
      throw std::runtime_error(
          "Used MPI implementation doesn't have the "
          "minimum level of threading support: "
          "MPI_THREAD_SERIALIZED. This is required by "
          "c10d package");
    }
    if (std::atexit(mpiExit)) {
      throw std::runtime_error("Fail to register the MPI exit handler");
    }
  });
}

std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI() {
  // Once initialization
  initMPIOnce();

  int rank = -1;
  int size = -1;
  // Update the world size and rank
  MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
  MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));

  if (rank < 0 || size < 0) {
    throw std::runtime_error("Failed to get the world_size / rank");
  }

  return std::make_shared<ProcessGroupMPI>(rank, size);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size)
    : ProcessGroup(rank, size), stop_(false) {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);

  if (mpiThreadSupport_ != MPI_THREAD_MULTIPLE && numProcessGroups_ >= 1) {
    throw std::runtime_error(
        "More than one process group created, "
        "this is not supported due to the used MPI "
        "implementation doesn't provide the full support "
        "of multi-threading");
  }
  // increase the total PG count
  ++numProcessGroups_;
  globalLock.unlock();

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!queue_.empty()) {
    queueConsumeCV_.wait(lock);
  }
  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  queueProduceCV_.notify_all();

  lock.unlock();

  // Join the single worker thread
  workerThread_.join();

  // Decrease the number of PG created
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  --numProcessGroups_;

}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();
    queueConsumeCV_.notify_one();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();

    try {
      workEntry->run(workEntry);
      work->finish();
    } catch (...) {
      work->finishWithException(std::current_exception());
    }

    lock.lock();
  }
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry) {
  auto work = std::make_shared<WorkMPI>();
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.push_back(std::make_tuple(std::move(entry), work));
  queueProduceCV_.notify_one();
  return work;
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts](std::unique_ptr<WorkEntry>& entry) {
        auto data = (*entry->src)[0];
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.type().scalarType()),
            opts.rootRank,
            MPI_COMM_WORLD));
      };
  auto entry = std::unique_ptr<WorkEntry>(
      new WorkEntry(&tensors, nullptr, std::move(runFunc)));
  return enqueue(std::move(entry));
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts](std::unique_ptr<WorkEntry>& entry) {
        auto data = (*entry->src)[0];
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.type().scalarType()),
            mpiOp.at(opts.reduceOp),
            MPI_COMM_WORLD));
      };
  auto entry = std::unique_ptr<WorkEntry>(
      new WorkEntry(&tensors, nullptr, std::move(runFunc)));
  return enqueue(std::move(entry));
}

} // namespace c10d
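
For illustration only, a minimal single-collective use of this process group might look like the sketch below. It assumes the ProcessGroupMPI.hpp interface implied by this file (createProcessGroupMPI, allreduce, Work::wait) plus the current ATen tensor factory API; the file name example.cpp and the tensor shape are made up, and the program would be launched under mpiexec/mpirun with several ranks.

// example.cpp (hypothetical) -- sum a tensor across all MPI ranks
#include "ProcessGroupMPI.hpp"
#include <ATen/ATen.h>

int main() {
  // Each rank creates the process group; MPI_Init_thread runs once per process.
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();

  // One contiguous CPU tensor per collective, as required by checkSingleTensor.
  std::vector<at::Tensor> tensors = {at::ones({4}, at::kFloat)};

  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::ReduceOp::SUM;

  // allreduce() only enqueues work; wait() blocks until the worker thread
  // has executed MPI_Allreduce and finished (or failed) the WorkMPI object.
  auto work = pg->allreduce(tensors, opts);
  if (!work->wait()) {
    // On failure, wait() returns false and exception() holds the error.
    return 1;
  }
  // tensors[0] now holds the world size in every element.
  return 0;
}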