
Conversation

@osalpekar (Member) commented Jul 7, 2020

Stack from ghstack:

This Commit:
Watchdog Thread checks for errored or timed-out WorkNCCL objects and aborts all associated NCCL Communicators. For now, we also process these aborted communicators as with the existing Watchdog logic (by adding them to abortedCommIds and writing the aborted communicator ids to the store).

This Stack:
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.
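To make the timeout-detection idea concrete, here is a minimal standalone C++ sketch of the pattern the watchdog thread uses. `FakeWork`, `FakeComm`, and `checkTimeouts` are made-up stand-ins, not the real `WorkNCCL`, NCCL communicator wrapper, or `ProcessGroupNCCL` code; the actual implementation is in the diff below.

```cpp
// Standalone sketch only: hypothetical FakeWork/FakeComm types stand in for
// WorkNCCL and the NCCL communicator wrapper; this is not ProcessGroupNCCL code.
#include <chrono>
#include <exception>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <vector>

struct FakeComm {
  bool aborted = false;
  void abort() { aborted = true; } // stand-in for ncclCommAbort()
};

struct FakeWork {
  std::chrono::steady_clock::time_point startTime;
  std::chrono::milliseconds opTimeout{0};
  std::exception_ptr exception;
  std::vector<std::shared_ptr<FakeComm>> comms;
};

// One watchdog pass: flag timed-out work items and abort their communicators.
void checkTimeouts(std::vector<std::shared_ptr<FakeWork>>& workList,
                   std::mutex& workListMutex) {
  std::unique_lock<std::mutex> lock(workListMutex);
  const auto now = std::chrono::steady_clock::now();
  for (auto& work : workList) {
    if (work->exception) {
      continue; // already errored; the error path handles this case
    }
    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        now - work->startTime);
    if (elapsed > work->opTimeout) {
      work->exception = std::make_exception_ptr(
          std::runtime_error("NCCL Operation Timed Out"));
      for (auto& comm : work->comms) {
        comm->abort();
      }
    }
  }
}

int main() {
  std::mutex workListMutex;
  auto work = std::make_shared<FakeWork>();
  // Pretend the collective started 10 seconds ago with a 1 second timeout.
  work->startTime = std::chrono::steady_clock::now() - std::chrono::seconds(10);
  work->opTimeout = std::chrono::milliseconds(1000);
  work->comms.push_back(std::make_shared<FakeComm>());

  std::vector<std::shared_ptr<FakeWork>> workList{work};
  checkTimeouts(workList, workListMutex);

  // A wait() on the work object would surface the stored exception to the user.
  if (work->exception) {
    try {
      std::rethrow_exception(work->exception);
    } catch (const std::runtime_error& e) {
      std::cout << "caught: " << e.what() << std::endl;
    }
  }
  return 0;
}
```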

Differential Revision: D21943151

…hread

Watchdog Thread checks for errored or timed-out `WorkNCCL` objects and aborts all associated NCCL Communicators. For now, we also process these aborted communicators as with the existing Watchdog logic (by adding them to abortedCommIds and writing the aborted communicator ids to the store).

Differential Revision: [D21943151](https://our.internmc.facebook.com/intern/diff/D21943151/)

[ghstack-poisoned]
dr-ci bot commented Jul 7, 2020

💊 CI failures summary and remediations

As of commit 6c816ed (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 08 22:32:09 what(): NCCL error: unhandled system error, NCCL version 2.7.6
Sep 08 22:32:02   test_scatter_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (2.829s) 
Sep 08 22:32:03   test_scatter_checks (__main__.ProcessGroupGlooTest) ... ok (0.128s) 
Sep 08 22:32:03   test_scatter_stress (__main__.ProcessGroupGlooTest) ... ok (0.524s) 
Sep 08 22:32:03   test_scatter_stress_cuda (__main__.ProcessGroupGlooTest) ... skip (0.001s) 
Sep 08 22:32:03   test_send_recv_all_to_all (__main__.ProcessGroupGlooTest) ... ok (0.125s) 
Sep 08 22:32:04   test_sparse_allreduce_basics (__main__.ProcessGroupGlooTest) ... ok (0.624s) 
Sep 08 22:32:07   test_sparse_allreduce_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (2.830s) 
Sep 08 22:32:07   test_sparse_allreduce_checks (__main__.ProcessGroupGlooTest) ... ok (0.126s) 
Sep 08 22:32:09   test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok (2.563s) 
Sep 08 22:32:09   test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... terminate called after throwing an instance of 'std::runtime_error' 
Sep 08 22:32:09   what():  NCCL error: unhandled system error, NCCL version 2.7.6 
Sep 08 22:32:10 Traceback (most recent call last): 
Sep 08 22:32:10   File "test/run_test.py", line 735, in <module> 
Sep 08 22:32:10     main() 
Sep 08 22:32:10   File "test/run_test.py", line 718, in main 
Sep 08 22:32:10     raise RuntimeError(err_message) 
Sep 08 22:32:10 RuntimeError: distributed/test_c10d failed! Received signal: SIGIOT 
Sep 08 22:32:10 + cleanup 
Sep 08 22:32:10 + retcode=1 
Sep 08 22:32:10 + set +x 
Sep 08 22:32:10 =================== sccache compilation log =================== 

ci.pytorch.org: 1 failed



osalpekar added 2 commits July 6, 2020 17:40
Comment on lines 488 to 494
// We should not abort the communicators if we are performing a
// non-blocking wait(). The reason for this is that if we abort the
// nccl communicator, wait() might not throw exceptions and
// subsequent operations might run on garbage results.
// The current model is that when we call wait(), subsequent
// operations only run after this work is done or we hang forever
// waiting for the operation to complete.
Contributor

It seems like we violate the contract described in this comment if we remove the blockingWait_ check here?

@jiayisuse (Contributor) commented Sep 3, 2020

Thanks for catching this. I guess users will get the same behavior as when blockingWait_ is true. We may need to revise this comment block to say that the aborted NCCL call will be caught by the cleanup thread and cause an exception.

Contributor

If NCCL_BLOCKING_WAIT is false and NCCL_ASYNC_ERROR_HANDLING is false, we would still end up aborting communicators here, which could cause consistency issues: other ops after the aborted collective might run on corrupted data.

Member Author

Good catch - thanks. We should probably guard this code block with `if (blockingWait_ || asyncErrorHandling_)` to handle this case.
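For concreteness, a rough sketch of what that guard could look like; the names (`blockingWait_`, `asyncErrorHandling_`, `abortedCommIds`, `buildNcclUniqueIdStr`) are taken from elsewhere in this diff and discussion, and this is only an illustration of the suggestion, not the final change:

```cpp
// Sketch only: abort communicators and record their ids only when blocking
// wait or async error handling is enabled, per the suggestion above.
if (blockingWait_ || asyncErrorHandling_) {
  for (const auto& ncclComm : work->ncclComms_) {
    ncclComm->ncclCommAbort();
    abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
  }
}
```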

Comment on lines +512 to 534
{
  std::unique_lock<std::mutex> lock(workListMutex_);
  for (auto& work : workList_) {
    work->checkAndSetException();
    // Aborting NCCL Communicators due to errors is already handled above.
    if (work->exception()) {
      continue;
    }

    // Check for Timeouts in the WorkNCCL Operations, and abort all
    // communicators accordingly.
    auto currentTimepoint = std::chrono::steady_clock::now();
    if (std::chrono::duration_cast<std::chrono::milliseconds>(
            currentTimepoint - work->workStartTime_) > work->opTimeout_) {
      std::exception_ptr exception_ptr = std::make_exception_ptr(
          std::runtime_error("NCCL Operation Timed Out"));
      work->setException(exception_ptr);
      for (const auto& ncclComm : work->ncclComms_) {
        ncclComm->ncclCommAbort();
        abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
      }
    }
  }
Contributor

This function is becoming pretty large, with multiple complex blocks. Can we move each block out into a separate helper function for more clarity?
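For example, the timeout-handling step could be pulled out into a helper along these lines. This is a sketch only: `checkTimedOutWork` is a hypothetical name, and the signature (including the type of `abortedCommIds` and taking the work item by reference) is assumed rather than taken from this PR.

```cpp
// Hypothetical helper (sketch): flag a timed-out WorkNCCL, abort its NCCL
// communicators, and record their ids. The watchdog loop would call this per
// work item while holding workListMutex_.
void ProcessGroupNCCL::checkTimedOutWork(
    ProcessGroupNCCL::WorkNCCL& work,
    std::unordered_set<std::string>& abortedCommIds) {
  const auto currentTimepoint = std::chrono::steady_clock::now();
  if (std::chrono::duration_cast<std::chrono::milliseconds>(
          currentTimepoint - work.workStartTime_) <= work.opTimeout_) {
    return; // still within the allowed time budget
  }
  work.setException(std::make_exception_ptr(
      std::runtime_error("NCCL Operation Timed Out")));
  for (const auto& ncclComm : work.ncclComms_) {
    ncclComm->ncclCommAbort();
    abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
  }
}
```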

@facebook-github-bot (Contributor)

This pull request has been merged in f8f7b78.

@facebook-github-bot deleted the gh/osalpekar/55/head branch September 13, 2020 14:17
loadbxh pushed a commit to loadbxh/Torch that referenced this pull request Sep 23, 2020
…hread

Pull Request resolved: pytorch/pytorch#41052

**This Commit:**
Watchdog Thread checks for errored or timed-out `WorkNCCL` objects and aborts all associated NCCL Communicators. For now, we also process these aborted communicators as with the existing Watchdog logic (by adding them to abortedCommIds and writing the aborted communicator ids to the store).

**This Stack:**
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.

ghstack-source-id: 111311021

Differential Revision: [D21943151](https://our.internmc.facebook.com/intern/diff/D21943151/)