6 changes: 2 additions & 4 deletions torch/_C/_distributed_c10d.pyi
@@ -408,10 +408,8 @@ class ProcessGroupNCCL(ProcessGroup):
size: int,
timeout: timedelta,
): ...
@staticmethod
def _group_start() -> None: ...
@staticmethod
Collaborator (PR author) comment:
Removing the static qualifier because _group_end() might need to check the ProcessGroup's communicator map in order to properly wait on collectives when nonblocking mode is used.

def _group_end() -> None: ...
def _group_start(self) -> None: ...
def _group_end(self) -> None: ...
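For context on the comment above, a minimal sketch of the idea (hypothetical; the struct and member names below are illustrative and not the PR's actual ProcessGroupNCCL code, which is not shown in this diff): a non-static _group_end can walk the process group's communicator map and poll each communicator until the grouped call settles when nonblocking mode is used.

// Hypothetical C++ sketch only (requires NCCL >= 2.14 for ncclInProgress).
// devNCCLCommMap_ and the exception type are stand-ins, not real members.
#include <nccl.h>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct HypotheticalProcessGroupNCCL {
  std::unordered_map<std::string, ncclComm_t> devNCCLCommMap_;  // device key -> comm
  bool nonblocking_ = false;

  void groupEnd() {
    ncclResult_t result = ncclGroupEnd();
    if (nonblocking_ && result == ncclInProgress) {
      // Poll every communicator owned by this process group until the
      // grouped call has settled, per NCCL's nonblocking semantics.
      result = ncclSuccess;
      for (auto& kv : devNCCLCommMap_) {
        ncclResult_t state = ncclInProgress;
        while (state == ncclInProgress) {
          ncclCommGetAsyncError(kv.second, &state);
        }
        if (state != ncclSuccess) {
          result = state;
          break;
        }
      }
    }
    if (result != ncclSuccess) {
      throw std::runtime_error("NCCL error in grouped call");
    }
  }
};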

class ProcessGroupUCC(ProcessGroup):
def __init__(
184 changes: 181 additions & 3 deletions torch/csrc/cuda/nccl.cpp
@@ -16,6 +16,11 @@
#include <type_traits>
#include <unordered_map>

#if !defined(USE_ROCM) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif

ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}
@@ -44,6 +49,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
return ncclResult_t::ncclInvalidUsage;
case torch::cuda::nccl::ncclResult::NumResults:
return ncclResult_t::ncclNumResults;
#ifdef NCCL_HAS_COMM_NONBLOCKING
case torch::cuda::nccl::ncclResult::InProgress:
return ncclResult_t::ncclInProgress;
#endif
default:
throw std::runtime_error("Unconvertible NCCL type");
}
@@ -65,6 +74,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
return torch::cuda::nccl::ncclResult::InvalidUsage;
case ncclNumResults:
return torch::cuda::nccl::ncclResult::NumResults;
#ifdef NCCL_HAS_COMM_NONBLOCKING
case ncclInProgress:
return torch::cuda::nccl::ncclResult::InProgress;
#endif
default:
throw std::runtime_error("Unconvertible NCCL type");
}
@@ -123,6 +136,105 @@ static inline void NCCL_CHECK(ncclResult_t result) {
NCCL_CHECK(from_nccl_result(result));
}

// TODO(eqy): can this duplication be avoided from NCCLUtils.cpp?
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN("Using experimental non-blocking NCCL communicator.");
}
return nccl_use_nonblocking_;
}

static int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
}
}
return timeout;
}

static int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
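For background on why these helpers exist: with NCCL 2.14 and newer, a communicator can be created in nonblocking mode, where calls may return ncclInProgress and the caller is expected to poll ncclCommGetAsyncError until the operation settles. A minimal sketch of that NCCL-level pattern follows (a standalone illustration of the documented NCCL API, not code from this PR):

// Create a nonblocking NCCL communicator (NCCL >= 2.14) and wait until it is usable.
#include <nccl.h>

ncclResult_t init_nonblocking_comm(
    ncclComm_t* comm, int nranks, ncclUniqueId id, int rank) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0;  // request nonblocking behavior
  ncclResult_t result = ncclCommInitRankConfig(comm, nranks, id, rank, &config);
  // In nonblocking mode the call typically returns ncclInProgress; poll the
  // communicator's async error state until it resolves to success or an error.
  while (result == ncclInProgress) {
    ncclCommGetAsyncError(*comm, &result);
  }
  return result;
}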

static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclResult_t result = to_nccl_result(status);
auto startTimepoint = std::chrono::steady_clock::now();
while (result == ncclInProgress) {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
}
ncclCommGetAsyncError(to_nccl_comm(comm), &result);
Collaborator comment:
I wonder if to_nccl_comm(comm) is needed here.

Here is definition of to_nccl_comm:

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

It seems to me comm is already an ncclComm_t (the one defined by NCCL).

Side note:
We should remove the duplicated ncclComm_t definition in torch::cuda::nccl; it is making things complicated.
That is out of scope for this PR, so we can do it later.

}
if (result != ncclSuccess) {
throw_nccl_error(from_nccl_result(result));
}
#else
TORCH_INTERNAL_ASSERT(
false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) {
NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm);
}

static inline void NCCL_CHECK_TIMEOUT(
ncclResult status,
std::vector<ncclComm_t>& comms) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclResult_t result = to_nccl_result(status);
auto startTimepoint = std::chrono::steady_clock::now();
if (result == ncclInProgress) {
for (const auto i : c10::irange(comms.size())) {
do {
if (nccl_nonblocking_timeout() > 0) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > nccl_nonblocking_timeout()) {
throw std::runtime_error("NCCL timeout.");
}
}
ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
} while (result == ncclInProgress);
if (result != ncclSuccess) {
break; /* fall through to failed case */
}
}
}
if (result != ncclSuccess) {
throw_nccl_error(from_nccl_result(result));
}
#else
TORCH_INTERNAL_ASSERT(
false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION.");
#endif
}

static inline void NCCL_CHECK_TIMEOUT(
ncclResult_t result,
std::vector<ncclComm_t>& comms) {
NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms);
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
std::ostringstream err;
err << "NCCL Error " << static_cast<int>(status) << ": "
@@ -308,9 +420,28 @@ AutoNcclGroup::AutoNcclGroup() {
#endif
}

AutoNcclGroup::AutoNcclGroup(
std::vector<ncclComm_t>& comms,
bool comm_nonblocking) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
// nccl < 2.0 cannot be called concurrently with cudaFree
(c10::cuda::getFreeMutex())->lock();
Collaborator comment:
nit: this is not needed anymore.
Reminder for myself.

#endif
// TODO(eqy): can we make comms_ reference?
comms_ = comms;
comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupEnd());
if (!comm_nonblocking_) {
detail::NCCL_CHECK(ncclGroupEnd());
} else {
detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
}
#endif
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
(c10::cuda::getFreeMutex())->unlock();
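To illustrate how the new constructor/destructor pair is meant to be used, here is a sketch of the caller pattern (the function name, the comms vector, and the unqualified nccl_use_nonblocking() call are illustrative; the exact namespaces are not visible in this diff):

// RAII sketch: group NCCL calls; on scope exit the destructor runs either a
// plain ncclGroupEnd (blocking comms) or the timeout-polling variant above.
#include <vector>
#include <torch/csrc/cuda/nccl.h>  // declares AutoNcclGroup and the ncclComm_t used below

void grouped_collectives(std::vector<torch::cuda::nccl::ncclComm_t>& comms) {
  torch::cuda::nccl::AutoNcclGroup guard(comms, nccl_use_nonblocking());
  // ... enqueue ncclSend / ncclRecv / other collectives on each comm here ...
}  // ~AutoNcclGroup: ncclGroupEnd() or NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_)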
Expand Down Expand Up @@ -677,7 +808,11 @@ void all2all_single_equal_split(
ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
@@ -730,7 +865,11 @@ void all2all_single_unequal_split(
stream));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -773,7 +912,11 @@ void all2all(
stream.stream()));
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -791,13 +934,25 @@ void send(
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
to_nccl_comm(comm),
stream.stream()));
#else
NCCL_CHECK_TIMEOUT(
ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
to_nccl_comm(comm),
stream.stream()),
comm);
#endif
#else
AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -815,13 +970,25 @@ void recv(
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
to_nccl_comm(comm),
stream.stream()));
#else
NCCL_CHECK_TIMEOUT(
ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
to_nccl_comm(comm),
stream.stream()),
comm);
#endif
#else
AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
@@ -865,7 +1032,11 @@ void gather(
} else {
NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());
#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif

#else
AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
@@ -888,9 +1059,13 @@ void scatter(

auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

#else
NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm);
NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm);
#endif
NCCL_CHECK(ncclGroupStart());
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
Expand All @@ -910,8 +1085,11 @@ void scatter(
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclGroupEnd());

#else
NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm);
#endif
#else
AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
6 changes: 5 additions & 1 deletion torch/csrc/cuda/nccl.h
@@ -46,7 +46,8 @@ enum class ncclResult {
InternalError = 3,
InvalidArgument = 4,
InvalidUsage = 5,
NumResults = 6
NumResults = 6,
InProgress = 7
Collaborator comment (on lines -49 to +50):
nit: since we are using an enum class here, it would indeed be more flexible not to number these enumerators explicitly. A reminder for myself to fix it later.
Ah, actually, we should remove this duplicated class entirely.

};

/* Reduction operation selector */
@@ -77,7 +78,10 @@ enum class ncclDataType {
// manages group and lock lifetimes.
struct AutoNcclGroup {
AutoNcclGroup();
AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
~AutoNcclGroup() noexcept(false);
std::vector<ncclComm_t> comms_;
bool comm_nonblocking_;
};

// NOTE: this is exposed only so that python_nccl.cpp can use some of these helpers.
30 changes: 30 additions & 0 deletions torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

#include <c10/util/CallOnce.h>
#include <c10/util/env.h>

#ifdef USE_C10D_NCCL

@@ -52,6 +53,35 @@ std::string getNcclVersion() {
return versionString;
}

bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN("Using experimental non-blocking NCCL communicator.");
Collaborator comment:
nit: TORCH_INFO?

}
return nccl_use_nonblocking_;
}

int _parse_nccl_nonblocking_timeout() {
const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
int timeout = -1;
if (val) {
const std::string config(val);
timeout = std::stoi(config);
if (!nccl_use_nonblocking() && timeout > 0) {
TORCH_WARN(
"TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
timeout = -1;
}
}
return timeout;
}

int nccl_nonblocking_timeout() {
static int timeout = _parse_nccl_nonblocking_timeout();
return timeout;
}
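Both knobs are read once and cached in function-local statics, so they must be present in the environment before the first call that consults them. A minimal illustration of setting them programmatically (the values are examples; in practice most users would export the variables in the launching shell):

#include <cstdlib>  // setenv (POSIX)

void enable_nonblocking_nccl_for_testing() {
  // Opt in to the experimental nonblocking communicator path.
  setenv("TORCH_NCCL_USE_COMM_NONBLOCKING", "1", /*overwrite=*/1);
  // Abort polling after 30 seconds if a call stays in ncclInProgress.
  setenv("TORCH_NCCL_NONBLOCKING_TIMEOUT", "30", /*overwrite=*/1);
}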

std::string ncclGetErrorWithVersion(ncclResult_t error) {
return std::string(ncclGetErrorString(error)) + ", NCCL version " +
getNcclVersion();