Skip to content

Commit c83c709

Browse files
committed
[NCCL] Add Environment Variable to guard Async Error Handling feature
Pull Request resolved: #44163 In this PR, we introduce a new environment variable (NCCL_ASYNC_ERROR_HANDLING), which guards the asynchronous error handling feature. We intend to eventually turn this feature on by default for all users, but this is a temporary solution so the change in behavior from hanging to crashing is not the default for users all of a sudden. ghstack-source-id: 111637788 Differential Revision: [D23517895](https://our.internmc.facebook.com/intern/diff/D23517895/)
1 parent 0cf96f4 commit c83c709

File tree

2 files changed

+59
-19
lines changed

2 files changed

+59
-19
lines changed

torch/lib/c10d/ProcessGroupNCCL.cpp

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,21 @@ void ProcessGroupNCCL::parseNcclBlockingWait() {
404404
}
405405
}
406406

407+
void ProcessGroupNCCL::parseNcclAsyncErrorHandling() {
408+
char* errorHandle = getenv(NCCL_ASYNC_ERROR_HANDLING);
409+
if (errorHandle != nullptr) {
410+
auto val = std::stoi(errorHandle);
411+
if (val == 1) {
412+
asyncErrorHandling_ = true;
413+
LOG(INFO) << "[Rank " << rank_ << "] NCCL Async Error Handling enabled.";
414+
} else if (val != 0) {
415+
throw std::runtime_error(
416+
"Invalid value for environment variable: " +
417+
std::string(NCCL_ASYNC_ERROR_HANDLING));
418+
}
419+
}
420+
}
421+
407422
bool ProcessGroupNCCL::WorkNCCL::timedOut() {
408423
auto currentTimepoint = std::chrono::steady_clock::now();
409424
return (
@@ -428,6 +443,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
428443
"Invalid value for environment variable: " +
429444
std::string(NCCL_BLOCKING_WAIT));
430445
}
446+
try {
447+
parseNcclAsyncErrorHandling();
448+
} catch (std::exception& e) {
449+
throw std::runtime_error(
450+
"Invalid value for environment variable: " +
451+
std::string(NCCL_ASYNC_ERROR_HANDLING));
452+
}
431453

432454
// If single-process single-device mode, WorkNCCL::getFuture is supported,
433455
// so get a dedicated stream for each device to run FutureNCCL then callbacks.
@@ -445,31 +467,36 @@ ProcessGroupNCCL::ProcessGroupNCCL(
445467
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
446468
#endif
447469

448-
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
470+
if (asyncErrorHandling_) {
471+
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
472+
}
449473
}
450474

451475
ProcessGroupNCCL::~ProcessGroupNCCL() {
452476
terminateProcessGroup_.store(true);
453477
watchdogCV_.notify_one();
454478
workListCV_.notify_one();
455479

456-
std::unique_lock<std::mutex> lock(workListMutex_);
457-
// TODO: We can potentially merge this functionality into the workCleanup
458-
// thread or just allow the destructor to free workList_.
459-
// Clean up any remaining items in the workList_ instead of waiting for the
460-
// workCleanup Thread to be scheduled again.
461-
for (auto it = workList_.begin(); it != workList_.end();
462-
/* no increment*/) {
463-
auto& work = *it;
464-
if (work->isCompleted()) {
465-
it = workList_.erase(it);
466-
} else {
467-
++it;
480+
if (asyncErrorHandling_) {
481+
std::unique_lock<std::mutex> lock(workListMutex_);
482+
// TODO: We can potentially merge this functionality into the workCleanup
483+
// thread or just allow the destructor to free workList_.
484+
// Clean up any remaining items in the workList_ instead of waiting for the
485+
// workCleanup Thread to be scheduled again.
486+
for (auto it = workList_.begin(); it != workList_.end();
487+
/* no increment*/) {
488+
auto& work = *it;
489+
if (work->isCompleted()) {
490+
it = workList_.erase(it);
491+
} else {
492+
++it;
493+
}
468494
}
495+
// Wait for workList_ to become empty before proceeding with shutdown.
496+
workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); });
497+
lock.unlock();
498+
workCleanupThread_.join();
469499
}
470-
// Wait for workList_ to become empty before proceeding with shutdown.
471-
workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); });
472-
lock.unlock();
473500

474501
#ifdef ENABLE_NCCL_ERROR_CHECKING
475502
ncclCommWatchdogThread_.join();
@@ -486,7 +513,6 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
486513
}
487514
}
488515
}
489-
workCleanupThread_.join();
490516
}
491517

492518
void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -542,7 +568,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
542568
}
543569
}
544570

545-
{
571+
if (asyncErrorHandling_) {
546572
std::unique_lock<std::mutex> lock(workListMutex_);
547573
for (auto& work : workList_) {
548574
work->checkAndSetException();
@@ -964,7 +990,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
964990
work->store_ = store_;
965991
}
966992

967-
workEnqueue(work);
993+
if (asyncErrorHandling_) {
994+
workEnqueue(work);
995+
}
968996

969997
return work;
970998
}

torch/lib/c10d/ProcessGroupNCCL.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ namespace c10d {
1919
// non-blocking.
2020
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
2121

22+
// Environment variable which controls whether or not we perform Async Error
23+
// Handling with NCCL.
24+
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";
25+
2226
// ProcessGroupNCCL implements NCCL bindings for c10d.
2327
//
2428
// All functions of the class are expected to be called in the same order
@@ -490,6 +494,10 @@ class ProcessGroupNCCL : public ProcessGroup {
490494
// accordingly.
491495
void parseNcclBlockingWait();
492496

497+
// Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_
498+
// accordingly.
499+
void parseNcclAsyncErrorHandling();
500+
493501
void workCleanupLoop();
494502

495503
protected:
@@ -594,6 +602,10 @@ class ProcessGroupNCCL : public ProcessGroup {
594602
// for the operation to complete.
595603
bool blockingWait_ = false;
596604

605+
// Whether ot not the workCleanupThread is used to perform async error
606+
// handling.
607+
bool asyncErrorHandling_ = false;
608+
597609
// Timeout for operations. This is only used when blockingWait_ is enabled.
598610
std::chrono::milliseconds opTimeout_;
599611

0 commit comments

Comments
 (0)