Skip to content

Commit 7bca02b

Browse files
committed
[NCCL] Add Environment Variable to guard Async Error Handling feature
In this PR, we introduce a new environment variable (NCCL_ASYNC_ERROR_HANDLING), which guards the asynchronous error handling feature. We intend to eventually turn this feature on by default for all users, but this is a temporary solution so the change in behavior from hanging to crashing is not the default for users all of a sudden. Differential Revision: [D23517895](https://our.internmc.facebook.com/intern/diff/D23517895/) ghstack-source-id: 111402543 Pull Request resolved: #44163
1 parent b4b62a7 commit 7bca02b

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

torch/lib/c10d/ProcessGroupNCCL.cpp

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,20 @@ void ProcessGroupNCCL::parseNcclBlockingWait() {
404404
}
405405
}
406406

407+
void ProcessGroupNCCL::parseNcclAsyncErrorHandling() {
408+
char* errorHandle = getenv(NCCL_ASYNC_ERROR_HANDLING);
409+
if (errorHandle != nullptr) {
410+
auto val = std::stoi(errorHandle);
411+
if (val == 1) {
412+
asyncErrorHandling_ = true;
413+
} else if (val != 0) {
414+
throw std::runtime_error(
415+
"Invalid value for environment variable: " +
416+
std::string(NCCL_ASYNC_ERROR_HANDLING));
417+
}
418+
}
419+
}
420+
407421
bool ProcessGroupNCCL::WorkNCCL::timedOut() {
408422
auto currentTimepoint = std::chrono::steady_clock::now();
409423
return (
@@ -428,6 +442,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
428442
"Invalid value for environment variable: " +
429443
std::string(NCCL_BLOCKING_WAIT));
430444
}
445+
try {
446+
parseNcclAsyncErrorHandling();
447+
} catch (std::exception& e) {
448+
throw std::runtime_error(
449+
"Invalid value for environment variable: " +
450+
std::string(NCCL_ASYNC_ERROR_HANDLING));
451+
}
431452

432453
// If single-process single-device mode, WorkNCCL::getFuture is supported,
433454
// so get a dedicated stream for each device to run FutureNCCL then callbacks.
@@ -445,31 +466,36 @@ ProcessGroupNCCL::ProcessGroupNCCL(
445466
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
446467
#endif
447468

448-
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
469+
if (asyncErrorHandling_) {
470+
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
471+
}
449472
}
450473

451474
ProcessGroupNCCL::~ProcessGroupNCCL() {
452475
terminateProcessGroup_.store(true);
453476
watchdogCV_.notify_one();
454477
workListCV_.notify_one();
455478

456-
std::unique_lock<std::mutex> lock(workListMutex_);
457-
// TODO: We can potentially merge this functionality into the workCleanup
458-
// thread or just allow the destructor to free workList_.
459-
// Clean up any remaining items in the workList_ instead of waiting for the
460-
// workCleanup Thread to be scheduled again.
461-
for (auto it = workList_.begin(); it != workList_.end();
462-
/* no increment*/) {
463-
auto& work = *it;
464-
if (work->isCompleted()) {
465-
it = workList_.erase(it);
466-
} else {
467-
++it;
479+
if (asyncErrorHandling_) {
480+
std::unique_lock<std::mutex> lock(workListMutex_);
481+
// TODO: We can potentially merge this functionality into the workCleanup
482+
// thread or just allow the destructor to free workList_.
483+
// Clean up any remaining items in the workList_ instead of waiting for the
484+
// workCleanup Thread to be scheduled again.
485+
for (auto it = workList_.begin(); it != workList_.end();
486+
/* no increment*/) {
487+
auto& work = *it;
488+
if (work->isCompleted()) {
489+
it = workList_.erase(it);
490+
} else {
491+
++it;
492+
}
468493
}
494+
// Wait for workList_ to become empty before proceeding with shutdown.
495+
workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); });
496+
lock.unlock();
497+
workCleanupThread_.join();
469498
}
470-
// Wait for workList_ to become empty before proceeding with shutdown.
471-
workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); });
472-
lock.unlock();
473499

474500
#ifdef ENABLE_NCCL_ERROR_CHECKING
475501
ncclCommWatchdogThread_.join();
@@ -486,7 +512,6 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
486512
}
487513
}
488514
}
489-
workCleanupThread_.join();
490515
}
491516

492517
void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -542,7 +567,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
542567
}
543568
}
544569

545-
{
570+
if (asyncErrorHandling_) {
546571
std::unique_lock<std::mutex> lock(workListMutex_);
547572
for (auto& work : workList_) {
548573
work->checkAndSetException();
@@ -964,7 +989,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
964989
work->store_ = store_;
965990
}
966991

967-
workEnqueue(work);
992+
if (asyncErrorHandling_) {
993+
workEnqueue(work);
994+
}
968995

969996
return work;
970997
}

torch/lib/c10d/ProcessGroupNCCL.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ namespace c10d {
1919
// non-blocking.
2020
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
2121

22+
// Environment variable which controls whether or not we perform Async Error
23+
// Handling with NCCL.
24+
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";
25+
2226
// ProcessGroupNCCL implements NCCL bindings for c10d.
2327
//
2428
// All functions of the class are expected to be called in the same order
@@ -490,6 +494,10 @@ class ProcessGroupNCCL : public ProcessGroup {
490494
// accordingly.
491495
void parseNcclBlockingWait();
492496

497+
// Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_
498+
// accordingly.
499+
void parseNcclAsyncErrorHandling();
500+
493501
void workCleanupLoop();
494502

495503
protected:
@@ -594,6 +602,10 @@ class ProcessGroupNCCL : public ProcessGroup {
594602
// for the operation to complete.
595603
bool blockingWait_ = false;
596604

605+
// Whether ot not the workCleanupThread is used to perform async error
606+
// handling.
607+
bool asyncErrorHandling_ = false;
608+
597609
// Timeout for operations. This is only used when blockingWait_ is enabled.
598610
std::chrono::milliseconds opTimeout_;
599611

0 commit comments

Comments
 (0)