
Commit 1df24fd

osalpekar authored and facebook-github-bot committed
[NCCL] Timeout Loop Thread for Async Error Handling (#41050)
Summary: Pull Request resolved: #41050

**This Commit:** We introduce a work list (workList_) to track live WorkNCCL objects corresponding to collective operations, along with a workCleanupLoop that periodically polls the list and removes WorkNCCL objects once they complete.

**This Stack:** The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal performance regression. Training can then be restarted from a previous checkpoint with a tool such as torchelastic.

Test Plan: See D22054298 for verification of correctness and performance.

Reviewed By: jiayisuse

Differential Revision: D21916637

fbshipit-source-id: f8cadaab0071aaad1c4e31f9b089aa23cba0cfbe
1 parent 15cbd1c · commit 1df24fd
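The pattern this commit introduces is easiest to see in isolation: a mutex-protected list of in-flight work objects, a producer that enqueues one entry per collective, and a cleanup thread that wakes on a fixed interval (or immediately on shutdown) and erases completed entries. The sketch below is a minimal, self-contained illustration of that pattern only; the `Work` and `WorkTracker` types, the 1000 ms interval, and the method names are simplified stand-ins, not the actual ProcessGroupNCCL API.

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <list>
#include <memory>
#include <mutex>
#include <thread>

// Hypothetical stand-in for WorkNCCL: only exposes completion status.
struct Work {
  std::atomic<bool> completed{false};
  bool isCompleted() const { return completed.load(); }
};

class WorkTracker {
 public:
  WorkTracker() : cleanupThread_(&WorkTracker::workCleanupLoop, this) {}

  ~WorkTracker() {
    // Signal shutdown, wake the cleanup thread, and wait for it to exit.
    terminate_.store(true);
    workListCV_.notify_one();
    cleanupThread_.join();
  }

  // Producer side: called once per collective that is launched.
  void workEnqueue(std::shared_ptr<Work> work) {
    if (!terminate_.load()) {
      std::lock_guard<std::mutex> lock(workListMutex_);
      workList_.emplace_back(std::move(work));
    }
  }

 private:
  void workCleanupLoop() {
    while (!terminate_.load()) {
      std::unique_lock<std::mutex> lock(workListMutex_);
      // Sleep for the poll interval, but wake immediately on shutdown.
      workListCV_.wait_for(lock, std::chrono::milliseconds(1000),
                           [&] { return terminate_.load(); });
      for (auto it = workList_.begin(); it != workList_.end();
           /* no increment */) {
        if ((*it)->isCompleted()) {
          it = workList_.erase(it);  // drop our reference to finished work
        } else {
          ++it;  // still in flight; check again on the next wake-up
        }
      }
    }
  }

  std::atomic<bool> terminate_{false};
  std::mutex workListMutex_;
  std::condition_variable workListCV_;
  std::list<std::shared_ptr<Work>> workList_;
  std::thread cleanupThread_;  // declared last so other members exist first
};

int main() {
  WorkTracker tracker;
  auto w = std::make_shared<Work>();
  tracker.workEnqueue(w);
  // ... the collective runs; later, mark it done:
  w->completed.store(true);
  // Give the cleanup thread one wake-up to erase the completed entry.
  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
  return 0;  // ~WorkTracker() stops and joins the cleanup thread
}
```

Declaring the thread member last matters: members are constructed in declaration order, so the mutex, condition variable, and list are all live before the cleanup thread starts. The destructor mirrors the commit's: set the termination flag, notify the condition variable so the thread skips the rest of its sleep, then join.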

2 files changed: +65, -6 lines


torch/lib/c10d/ProcessGroupNCCL.cpp

Lines changed: 44 additions & 4 deletions
@@ -228,6 +228,7 @@ ncclResult_t ncclAlltoallv(
 } // namespace
 
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
+const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000;
 constexpr int64_t kWaitForAbortCommStoreKey = 1000;
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000;
@@ -399,7 +400,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     : ProcessGroup(rank, size),
       store_(store),
       ncclCommCounter_(0),
-      terminateWatchdog_(false),
+      terminateProcessGroup_(false),
       opTimeout_(opTimeout) {
   try {
     parseNcclBlockingWait();
@@ -424,11 +425,14 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   ncclCommWatchdogThread_ =
       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
 #endif
+
+  workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
 }
 
 ProcessGroupNCCL::~ProcessGroupNCCL() {
-  terminateWatchdog_.store(true);
+  terminateProcessGroup_.store(true);
   watchdogCV_.notify_one();
+  workListCV_.notify_one();
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   ncclCommWatchdogThread_.join();
 #endif
@@ -444,6 +448,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
       }
     }
   }
+  workCleanupThread_.join();
 }
 
 void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -458,7 +463,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
 }
 
 void ProcessGroupNCCL::ncclCommWatchdogInternal() {
-  while (!terminateWatchdog_.load()) {
+  while (!terminateProcessGroup_.load()) {
     std::unordered_set<std::string> abortedCommIds;
     std::unordered_set<std::string> allCommIds;
 
@@ -554,7 +559,32 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
     watchdogCV_.wait_for(
         lock,
         std::chrono::milliseconds(kWatchdogThreadSleepMillis),
-        [&]() -> bool { return terminateWatchdog_.load(); });
+        [&]() -> bool { return terminateProcessGroup_.load(); });
+  }
+}
+
+void ProcessGroupNCCL::workCleanupLoop() {
+  while (!terminateProcessGroup_.load()) {
+    std::unique_lock<std::mutex> lock(workListMutex_);
+    // Poll the work list every kWorkCleanupThreadSleepMillis milliseconds,
+    // waking up early if the process group is being terminated.
+    workListCV_.wait_for(
+        lock,
+        std::chrono::milliseconds(kWorkCleanupThreadSleepMillis),
+        [&]() -> bool { return terminateProcessGroup_.load(); });
+
+    for (auto it = workList_.begin(); it != workList_.end();
+         /* no increment */) {
+      auto& work = *it;
+      if (work->isCompleted()) {
+        // Remove completed WorkNCCL objects from the list.
+        it = workList_.erase(it);
+      } else {
+        // Increment the iterator if the current WorkNCCL object is not
+        // completed.
+        ++it;
+      }
+    }
   }
 }
 
@@ -797,6 +827,14 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
       futureNCCLCallbackStreams_[deviceIndex]);
 }
 
+void ProcessGroupNCCL::workEnqueue(
+    std::shared_ptr<ProcessGroupNCCL::WorkNCCL> work) {
+  if (!terminateProcessGroup_.load()) {
+    std::lock_guard<std::mutex> lock(workListMutex_);
+    workList_.emplace_back(std::move(work));
+  }
+}
+
 template <typename Fn, typename PreProcess, typename PostProcess>
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     std::vector<at::Tensor>& inputs,
@@ -861,6 +899,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     work->store_ = store_;
   }
 
+  workEnqueue(work);
+
   return work;
 }
 
torch/lib/c10d/ProcessGroupNCCL.hpp

Lines changed: 21 additions & 2 deletions
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <list>
 #include <mutex>
 #include <thread>
 #include <unordered_map>
@@ -478,8 +479,11 @@ class ProcessGroupNCCL : public ProcessGroup {
   // accordingly.
   void parseNcclBlockingWait();
 
+  void workCleanupLoop();
+
  protected:
   static const int64_t kWatchdogThreadSleepMillis;
+  static const int64_t kWorkCleanupThreadSleepMillis;
 
   // The store is used to broadcast the NCCL unique ID of rank 0.
   std::shared_ptr<Store> store_;
@@ -521,15 +525,30 @@ class ProcessGroupNCCL : public ProcessGroup {
   // Watchdog thread which looks for errors on the cached NCCL communicators.
   std::thread ncclCommWatchdogThread_;
 
-  // Whether or not we should terminate the watchdog thread.
-  std::atomic<bool> terminateWatchdog_;
+  // Whether or not we should terminate the watchdog and workCleanup threads.
+  std::atomic<bool> terminateProcessGroup_;
 
   // Condition variable to control how long the watchdog thread waits.
   std::condition_variable watchdogCV_;
 
   // Mutex for watchdog.
   std::mutex watchdogCVMutex_;
 
+  // Thread that removes completed NCCL work objects from workList_.
+  std::thread workCleanupThread_;
+
+  // Mutex to guard workList_.
+  std::mutex workListMutex_;
+
+  // Condition variable that controls how long the work cleanup thread sleeps.
+  std::condition_variable workListCV_;
+
+  // List of outstanding WorkNCCL pointers.
+  std::list<std::shared_ptr<ProcessGroupNCCL::WorkNCCL>> workList_;
+
+  // Add a WorkNCCL pointer to workList_.
+  void workEnqueue(std::shared_ptr<ProcessGroupNCCL::WorkNCCL>);
+
   // The CUDA streams used by NCCL kernels
   std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
       ncclStreams_;
