-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[NCCL] Add experimental Nonblocking NCCL Fault Tolerance/Checking #95715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f159947
0c31b7d
b7f5683
bbe086f
bbfa12a
eff18d0
6c5c5e0
15d2b30
3479ad4
6be4c95
4455b0f
cdafc1d
072e2c7
a61e802
5fa1842
71fbe6c
8dee33e
3ae5791
0cdb849
45a6c59
2afb09e
081b058
915cbbe
7cadf83
8b5c091
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,11 @@ | |
| #include <type_traits> | ||
| #include <unordered_map> | ||
|
|
||
// Non-blocking communicator support (ncclCommGetAsyncError reporting
// ncclInProgress) was introduced in NCCL 2.14 and is unavailable on ROCm.
// NOTE: fixed typo NCCL_MACJOR -> NCCL_MAJOR; with the typo, an undefined
// identifier evaluates to 0 in the preprocessor, so the "major > 2" arm could
// never fire and NCCL 3.x would silently lose non-blocking support.
#if !defined(USE_ROCM) && \
    ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 14)))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif
|
|
||
// Reinterprets torch's wrapper communicator pointer as the real NCCL handle.
// Presumably torch::cuda::nccl::ncclComm_t is an opaque mirror of the library
// type, making the cast layout-safe — TODO(review): confirm against the
// wrapper's declaration.
| ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) { | ||
| return reinterpret_cast<ncclComm_t*>(var); | ||
| } | ||
|
|
@@ -44,6 +49,10 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) { | |
| return ncclResult_t::ncclInvalidUsage; | ||
| case torch::cuda::nccl::ncclResult::NumResults: | ||
| return ncclResult_t::ncclNumResults; | ||
| #ifdef NCCL_HAS_COMM_NONBLOCKING | ||
| case torch::cuda::nccl::ncclResult::InProgress: | ||
| return ncclResult_t::ncclInProgress; | ||
| #endif | ||
| default: | ||
| throw std::runtime_error("Unconvertible NCCL type"); | ||
| } | ||
|
|
@@ -65,6 +74,10 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) { | |
| return torch::cuda::nccl::ncclResult::InvalidUsage; | ||
| case ncclNumResults: | ||
| return torch::cuda::nccl::ncclResult::NumResults; | ||
| #ifdef NCCL_HAS_COMM_NONBLOCKING | ||
| case ncclInProgress: | ||
| return torch::cuda::nccl::ncclResult::InProgress; | ||
| #endif | ||
| default: | ||
| throw std::runtime_error("Unconvertible NCCL type"); | ||
| } | ||
|
|
@@ -123,6 +136,105 @@ static inline void NCCL_CHECK(ncclResult_t result) { | |
| NCCL_CHECK(from_nccl_result(result)); | ||
| } | ||
|
|
||
| // TODO(eqy): can this duplication be avoided from NCCLUtils.cpp? | ||
| bool nccl_use_nonblocking() { | ||
| static bool nccl_use_nonblocking_ = | ||
| c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true; | ||
| if (nccl_use_nonblocking_) { | ||
| TORCH_WARN("Using experimental non-blocking NCCL communicator."); | ||
| } | ||
| return nccl_use_nonblocking_; | ||
| } | ||
|
|
||
| static int _parse_nccl_nonblocking_timeout() { | ||
| const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); | ||
| int timeout = -1; | ||
| if (val) { | ||
| const std::string config(val); | ||
| timeout = std::stoi(config); | ||
| if (!nccl_use_nonblocking() && timeout > 0) { | ||
| TORCH_WARN( | ||
| "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); | ||
| timeout = -1; | ||
| } | ||
| } | ||
| return timeout; | ||
| } | ||
|
|
||
| static int nccl_nonblocking_timeout() { | ||
| static int timeout = _parse_nccl_nonblocking_timeout(); | ||
| return timeout; | ||
| } | ||
|
|
||
| static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) { | ||
| #ifdef NCCL_HAS_COMM_NONBLOCKING | ||
| ncclResult_t result = to_nccl_result(status); | ||
| auto startTimepoint = std::chrono::steady_clock::now(); | ||
| while (result == ncclInProgress) { | ||
| if (nccl_nonblocking_timeout() > 0) { | ||
| auto currentTimepoint = std::chrono::steady_clock::now(); | ||
| auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( | ||
| currentTimepoint - startTimepoint) | ||
| .count(); | ||
| if (timeElapsed > nccl_nonblocking_timeout()) { | ||
| throw std::runtime_error("NCCL timeout."); | ||
| } | ||
| } | ||
| ncclCommGetAsyncError(to_nccl_comm(comm), &result); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if Here is definition of It seems to me Side note: |
||
| } | ||
| if (result != ncclSuccess) { | ||
| throw_nccl_error(from_nccl_result(result)); | ||
| } | ||
| #else | ||
| TORCH_INTERNAL_ASSERT( | ||
| false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION."); | ||
| #endif | ||
| } | ||
|
|
||
| static inline void NCCL_CHECK_TIMEOUT(ncclResult_t result, ncclComm_t comm) { | ||
| NCCL_CHECK_TIMEOUT(from_nccl_result(result), comm); | ||
| } | ||
|
|
||
| static inline void NCCL_CHECK_TIMEOUT( | ||
| ncclResult status, | ||
| std::vector<ncclComm_t>& comms) { | ||
| #ifdef NCCL_HAS_COMM_NONBLOCKING | ||
| ncclResult_t result = to_nccl_result(status); | ||
| auto startTimepoint = std::chrono::steady_clock::now(); | ||
| if (result == ncclInProgress) { | ||
| for (const auto i : c10::irange(comms.size())) { | ||
| do { | ||
| if (nccl_nonblocking_timeout() > 0) { | ||
| auto currentTimepoint = std::chrono::steady_clock::now(); | ||
| auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( | ||
| currentTimepoint - startTimepoint) | ||
| .count(); | ||
| if (timeElapsed > nccl_nonblocking_timeout()) { | ||
| throw std::runtime_error("NCCL timeout."); | ||
| } | ||
| } | ||
| ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result); | ||
| } while (result == ncclInProgress); | ||
| if (result != ncclSuccess) { | ||
| break; /* fall through to failed case */ | ||
| } | ||
| } | ||
| } | ||
| if (result != ncclSuccess) { | ||
| throw_nccl_error(from_nccl_result(result)); | ||
| } | ||
| #else | ||
| TORCH_INTERNAL_ASSERT( | ||
| false, "NCCL COMM NONBLOCKING USED WITH UNSUPPORTED NCCL VERSION."); | ||
| #endif | ||
| } | ||
|
|
||
| static inline void NCCL_CHECK_TIMEOUT( | ||
| ncclResult_t result, | ||
| std::vector<ncclComm_t>& comms) { | ||
| NCCL_CHECK_TIMEOUT(from_nccl_result(result), comms); | ||
| } | ||
|
|
||
| void throw_nccl_error(torch::cuda::nccl::ncclResult status) { | ||
| std::ostringstream err; | ||
| err << "NCCL Error " << static_cast<int>(status) << ": " | ||
|
|
@@ -308,9 +420,28 @@ AutoNcclGroup::AutoNcclGroup() { | |
| #endif | ||
| } | ||
|
|
||
| AutoNcclGroup::AutoNcclGroup( | ||
| std::vector<ncclComm_t>& comms, | ||
| bool comm_nonblocking) { | ||
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) | ||
| // nccl < 2.0 cannot be called concurrently with cudaFree | ||
| (c10::cuda::getFreeMutex())->lock(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is not needed anymore. |
||
| #endif | ||
| // TODO(eqy): can we make comms_ reference? | ||
| comms_ = comms; | ||
| comm_nonblocking_ = comm_nonblocking; | ||
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) | ||
| detail::NCCL_CHECK(ncclGroupStart()); | ||
| #endif | ||
| } | ||
|
|
||
| AutoNcclGroup::~AutoNcclGroup() noexcept(false) { | ||
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) | ||
| detail::NCCL_CHECK(ncclGroupEnd()); | ||
| if (!comm_nonblocking_) { | ||
| detail::NCCL_CHECK(ncclGroupEnd()); | ||
| } else { | ||
| detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_); | ||
| } | ||
| #endif | ||
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) | ||
| (c10::cuda::getFreeMutex())->unlock(); | ||
|
|
@@ -677,7 +808,11 @@ void all2all_single_equal_split( | |
| ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); | ||
| } | ||
| } | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclGroupEnd()); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); | ||
| #endif | ||
| #endif | ||
| #else | ||
| AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); | ||
|
|
@@ -730,7 +865,11 @@ void all2all_single_unequal_split( | |
| stream)); | ||
| } | ||
| } | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclGroupEnd()); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); | ||
| #endif | ||
| #else | ||
| AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); | ||
| #endif | ||
|
|
@@ -773,7 +912,11 @@ void all2all( | |
| stream.stream())); | ||
| } | ||
| } | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclGroupEnd()); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); | ||
| #endif | ||
| #else | ||
| AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); | ||
| #endif | ||
|
|
@@ -791,13 +934,25 @@ void send( | |
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ | ||
| (NCCL_MINOR >= 7) | ||
| using namespace torch::cuda::nccl::detail; | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclSend( | ||
| input.data_ptr(), | ||
| input.numel(), | ||
| to_nccl_data_type(input), | ||
| dst, | ||
| to_nccl_comm(comm), | ||
| stream.stream())); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT( | ||
| ncclSend( | ||
| input.data_ptr(), | ||
| input.numel(), | ||
| to_nccl_data_type(input), | ||
| dst, | ||
| to_nccl_comm(comm), | ||
| stream.stream()), | ||
| comm); | ||
| #endif | ||
| #else | ||
| AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); | ||
| #endif | ||
|
|
@@ -815,13 +970,25 @@ void recv( | |
| #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ | ||
| (NCCL_MINOR >= 7) | ||
| using namespace torch::cuda::nccl::detail; | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclRecv( | ||
| output.data_ptr(), | ||
| output.numel(), | ||
| to_nccl_data_type(output), | ||
| src, | ||
| to_nccl_comm(comm), | ||
| stream.stream())); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT( | ||
| ncclRecv( | ||
| output.data_ptr(), | ||
| output.numel(), | ||
| to_nccl_data_type(output), | ||
| src, | ||
| to_nccl_comm(comm), | ||
| stream.stream()), | ||
| comm); | ||
| #endif | ||
| #else | ||
| AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); | ||
| #endif | ||
|
|
@@ -865,7 +1032,11 @@ void gather( | |
| } else { | ||
| NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream)); | ||
| } | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclGroupEnd()); | ||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); | ||
| #endif | ||
|
|
||
| #else | ||
| AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0"); | ||
|
|
@@ -888,9 +1059,13 @@ void scatter( | |
|
|
||
| auto comm = to_nccl_comm(_comm); | ||
| int numranks, cur_rank; | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclCommCount(comm, &numranks)); | ||
| NCCL_CHECK(ncclCommUserRank(comm, &cur_rank)); | ||
|
|
||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclCommCount(comm, &numranks), _comm); | ||
| NCCL_CHECK_TIMEOUT(ncclCommUserRank(comm, &cur_rank), _comm); | ||
| #endif | ||
| NCCL_CHECK(ncclGroupStart()); | ||
| if (cur_rank == root) { | ||
| for (const auto r : c10::irange(numranks)) { | ||
|
|
@@ -910,8 +1085,11 @@ void scatter( | |
| auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr()); | ||
| NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream)); | ||
| } | ||
| #ifndef NCCL_HAS_COMM_NONBLOCKING | ||
| NCCL_CHECK(ncclGroupEnd()); | ||
|
|
||
| #else | ||
| NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); | ||
| #endif | ||
| #else | ||
| AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0"); | ||
| #endif | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,7 +46,8 @@ enum class ncclResult { | |
| InternalError = 3, | ||
| InvalidArgument = 4, | ||
| InvalidUsage = 5, | ||
| NumResults = 6 | ||
| NumResults = 6, | ||
| InProgress = 7 | ||
|
Comment on lines
-49
to
+50
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: since we are using |
||
| }; | ||
|
|
||
| /* Reduction operation selector */ | ||
|
|
@@ -77,7 +78,10 @@ enum class ncclDataType { | |
| // manages group and lock lifetimes. | ||
| struct AutoNcclGroup { | ||
| AutoNcclGroup(); | ||
// Non-blocking variant: records the communicators so the destructor can poll
// them for completion instead of relying on a blocking ncclGroupEnd.
| AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking); | ||
// noexcept(false): group teardown / timeout polling may throw.
| ~AutoNcclGroup() noexcept(false); | ||
// NOTE(review): comms_ stores a copy of the caller's vector (see the TODO in
// the .cpp about making it a reference).
| std::vector<ncclComm_t> comms_; | ||
| bool comm_nonblocking_; | ||
| }; | ||
|
|
||
| // NOTE: this is exposed only so that python_nccl.cpp can some of these helpers. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> | ||
|
|
||
| #include <c10/util/CallOnce.h> | ||
| #include <c10/util/env.h> | ||
|
|
||
| #ifdef USE_C10D_NCCL | ||
|
|
||
|
|
@@ -52,6 +53,35 @@ std::string getNcclVersion() { | |
| return versionString; | ||
| } | ||
|
|
||
| bool nccl_use_nonblocking() { | ||
| static bool nccl_use_nonblocking_ = | ||
| c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true; | ||
| if (nccl_use_nonblocking_) { | ||
| TORCH_WARN("Using experimental non-blocking NCCL communicator."); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
| } | ||
| return nccl_use_nonblocking_; | ||
| } | ||
|
|
||
| int _parse_nccl_nonblocking_timeout() { | ||
| const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); | ||
| int timeout = -1; | ||
| if (val) { | ||
| const std::string config(val); | ||
| timeout = std::stoi(config); | ||
| if (!nccl_use_nonblocking() && timeout > 0) { | ||
| TORCH_WARN( | ||
| "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); | ||
| timeout = -1; | ||
| } | ||
| } | ||
| return timeout; | ||
| } | ||
|
|
||
| int nccl_nonblocking_timeout() { | ||
| static int timeout = _parse_nccl_nonblocking_timeout(); | ||
| return timeout; | ||
| } | ||
|
|
||
| std::string ncclGetErrorWithVersion(ncclResult_t error) { | ||
| return std::string(ncclGetErrorString(error)) + ", NCCL version " + | ||
| getNcclVersion(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing the static method, as
`_group_end()` might need to check the communicator map of the `ProcessGroup` to properly wait on collectives if nonblocking is used.