Conversation

@osalpekar (Member) commented Jul 7, 2020

Stack from ghstack:

This Commit:
In the workCleanupThread, we process completion and exception handling for WorkNCCL objects corresponding to collective calls that have either completed GPU execution or have already thrown an exception. This way, failed GPU operations surface as exceptions thrown from the workCleanupThread. This replaces the previous (and lower-performance) approach of enqueuing a callback on the CUDA stream to process failures.

This Stack:
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.
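
The description under "This Commit" boils down to a polling loop: a background thread periodically calls cudaEventQuery on an event recorded after each collective and marks the work as completed, failed, or timed out. The C++ sketch below illustrates only that pattern; it is not the actual ProcessGroupNCCL code, and the names (`PendingWork`, `WorkCleanup`, `waitAndCheck`) and the 100 ms poll interval are illustrative assumptions.

```cpp
#include <cuda_runtime.h>

#include <atomic>
#include <chrono>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>

// Stand-in for a WorkNCCL object: a CUDA event recorded on the collective's
// stream right after the NCCL call, a timeout deadline, and a slot for any
// error the cleanup thread detects.
struct PendingWork {
  cudaEvent_t doneEvent{};
  std::chrono::steady_clock::time_point deadline;
  std::exception_ptr error;           // written by the cleanup thread
  std::atomic<bool> finished{false};  // set once the cleanup thread is done
};

class WorkCleanup {
 public:
  WorkCleanup() : poller_([this] { loop(); }) {}
  ~WorkCleanup() {
    stop_ = true;
    poller_.join();
  }

  // Called by the launcher thread after recording `event` on the stream
  // that ran the collective.
  std::shared_ptr<PendingWork> enqueue(
      cudaEvent_t event, std::chrono::milliseconds timeout) {
    auto work = std::make_shared<PendingWork>();
    work->doneEvent = event;
    work->deadline = std::chrono::steady_clock::now() + timeout;
    std::lock_guard<std::mutex> lock(mu_);
    pending_.push_back(work);
    return work;
  }

 private:
  void loop() {
    while (!stop_) {
      // Poll periodically; nothing is enqueued on the CUDA stream itself.
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      std::lock_guard<std::mutex> lock(mu_);
      for (auto it = pending_.begin(); it != pending_.end();) {
        auto& work = *it;
        cudaError_t err = cudaEventQuery(work->doneEvent);  // non-blocking
        bool timedOut = std::chrono::steady_clock::now() > work->deadline;
        if (err == cudaErrorNotReady && !timedOut) {
          ++it;  // still running on the GPU; check again on the next pass
          continue;
        }
        if (err == cudaErrorNotReady) {  // deadline passed, still not done
          work->error = std::make_exception_ptr(
              std::runtime_error("NCCL collective timed out"));
        } else if (err != cudaSuccess) {  // the GPU operation failed
          work->error = std::make_exception_ptr(
              std::runtime_error(cudaGetErrorString(err)));
        }
        work->finished = true;    // publish the outcome to waiters
        it = pending_.erase(it);  // this work no longer needs polling
      }
    }
  }

  std::mutex mu_;
  std::deque<std::shared_ptr<PendingWork>> pending_;
  std::atomic<bool> stop_{false};
  std::thread poller_;
};

// Caller side: block until the cleanup thread has ruled on this work, then
// rethrow any error it recorded so the failure is user-visible.
void waitAndCheck(const std::shared_ptr<PendingWork>& work) {
  while (!work->finished) {
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  if (work->error) {
    std::rethrow_exception(work->error);
  }
}
```

Compared with enqueuing a host callback on the CUDA stream, polling keeps the stream free of host-side work and lets the same thread enforce timeouts, which is the perf and hang-detection argument made in the stack description above.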

Differential Revision: D21938498

Creating host guard function that throws exception in error cases and registering guard function as a callback on current CUDA stream

Differential Revision: [D21938498](https://our.internmc.facebook.com/intern/diff/D21938498/)

[ghstack-poisoned]

dr-ci bot commented Jul 7, 2020

💊 CI failures summary and remediations

As of commit 019bf7b (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 08 22:30:43 what(): NCCL error: unhandled system error, NCCL version 2.7.6
Sep 08 22:30:36   test_scatter_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (1.926s) 
Sep 08 22:30:36   test_scatter_checks (__main__.ProcessGroupGlooTest) ... ok (0.127s) 
Sep 08 22:30:37   test_scatter_stress (__main__.ProcessGroupGlooTest) ... ok (0.623s) 
Sep 08 22:30:37   test_scatter_stress_cuda (__main__.ProcessGroupGlooTest) ... skip (0.001s) 
Sep 08 22:30:37   test_send_recv_all_to_all (__main__.ProcessGroupGlooTest) ... ok (0.124s) 
Sep 08 22:30:38   test_sparse_allreduce_basics (__main__.ProcessGroupGlooTest) ... ok (0.623s) 
Sep 08 22:30:40   test_sparse_allreduce_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (2.527s) 
Sep 08 22:30:41   test_sparse_allreduce_checks (__main__.ProcessGroupGlooTest) ... ok (0.125s) 
Sep 08 22:30:43   test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok (2.640s) 
Sep 08 22:30:43   test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... terminate called after throwing an instance of 'std::runtime_error' 
Sep 08 22:30:43   what():  NCCL error: unhandled system error, NCCL version 2.7.6 
Sep 08 22:30:43 Traceback (most recent call last): 
Sep 08 22:30:43   File "test/run_test.py", line 735, in <module> 
Sep 08 22:30:43     main() 
Sep 08 22:30:43   File "test/run_test.py", line 718, in main 
Sep 08 22:30:43     raise RuntimeError(err_message) 
Sep 08 22:30:43 RuntimeError: distributed/test_c10d failed! Received signal: SIGIOT 
Sep 08 22:30:44 + cleanup 
Sep 08 22:30:44 + retcode=1 
Sep 08 22:30:44 + set +x 
Sep 08 22:30:44 =================== sccache compilation log =================== 

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 59 times.

osalpekar added 2 commits July 6, 2020 17:40
…onize"

Creating host guard function that throws exception in error cases and registering guard function as a callback on current CUDA stream

Differential Revision: [D21938498](https://our.internmc.facebook.com/intern/diff/D21938498/)

[ghstack-poisoned]
…onize"

Creating host guard function that throws exception in error cases and registering guard function as a callback on current CUDA stream

Differential Revision: [D21938498](https://our.internmc.facebook.com/intern/diff/D21938498/)

[ghstack-poisoned]
…onize"

Creating host guard function that throws exception in error cases and registering guard function as a callback on current CUDA stream

Differential Revision: [D21938498](https://our.internmc.facebook.com/intern/diff/D21938498/)

[ghstack-poisoned]
@osalpekar changed the title from "[NCCL] Register Guard Function Callback in WorkNCCL Synchronize" to "[NCCL] Use cudaEventQuery to Poll for GPU operation errors" on Aug 19, 2020

@facebook-github-bot (Contributor)

This pull request has been merged in 4e5c55e.

@facebook-github-bot deleted the gh/osalpekar/54/head branch September 13, 2020 14:16
loadbxh pushed a commit to loadbxh/Torch that referenced this pull request Sep 23, 2020
Pull Request resolved: pytorch/pytorch#41051

ghstack-source-id: 111301606

Differential Revision: [D21938498](https://our.internmc.facebook.com/intern/diff/D21938498/)