
Conversation

@osalpekar
Member

osalpekar commented Jul 7, 2020

Stack from ghstack:

This Commit:
Some minor refactoring: added a helper to check whether `WorkNCCL` objects have timed out, added a new finish function to `ProcessGroupNCCL::WorkNCCL` that uses a `lock_guard` and avoids notifying the condition variable, and renamed the `timeoutCVMutex` mutex to be more descriptive.
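As a rough illustration (not the actual PyTorch source), here is a minimal C++ sketch of the two pieces this commit describes: a `timedOut()`-style helper and a finish variant that takes the lock with `std::lock_guard` and skips the condition-variable notification. All member and method names below (`workStartTime_`, `opTimeout_`, `completed_`, `workMutex_`, `finishWithoutNotify`) are hypothetical placeholders chosen only to make the sketch self-contained.

```cpp
#include <chrono>
#include <condition_variable>
#include <mutex>

// Simplified, hypothetical stand-in for ProcessGroupNCCL::WorkNCCL.
class WorkNCCL {
 public:
  explicit WorkNCCL(std::chrono::milliseconds opTimeout)
      : workStartTime_(std::chrono::steady_clock::now()),
        opTimeout_(opTimeout) {}

  // Timeout-check helper: has this work exceeded its configured timeout
  // without completing?
  bool timedOut() const {
    std::lock_guard<std::mutex> lock(workMutex_);
    auto elapsed = std::chrono::steady_clock::now() - workStartTime_;
    return !completed_ && elapsed >= opTimeout_;
  }

  // Existing-style finish: mark completion and wake any thread blocked
  // waiting on the condition variable.
  void finish() {
    {
      std::lock_guard<std::mutex> lock(workMutex_);
      completed_ = true;
    }
    workCV_.notify_all();
  }

  // New variant: mark completion under a lock_guard but deliberately skip
  // the condition-variable notification.
  void finishWithoutNotify() {
    std::lock_guard<std::mutex> lock(workMutex_);
    completed_ = true;
  }

 private:
  std::chrono::steady_clock::time_point workStartTime_;
  std::chrono::milliseconds opTimeout_;
  bool completed_ = false;
  mutable std::mutex workMutex_;  // renamed from the less descriptive timeoutCVMutex
  std::condition_variable workCV_;
};
```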

This Stack:
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.
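To show how such a timeout check fits into the larger goal of the stack, here is a hedged sketch of a watchdog-style loop that aborts timed-out collectives and surfaces the failure as an exception instead of letting training hang. The `TimedWork` interface and its methods (`abortCollective`, `setException`) are hypothetical placeholders, not the actual ProcessGroupNCCL API; the real implementation details differ.

```cpp
#include <atomic>
#include <chrono>
#include <exception>
#include <memory>
#include <stdexcept>
#include <thread>
#include <vector>

// Hypothetical interface for an in-flight collective, standing in for the
// WorkNCCL sketch above.
struct TimedWork {
  virtual ~TimedWork() = default;
  virtual bool timedOut() const = 0;                    // exceeded its timeout?
  virtual void abortCollective() = 0;                   // abort NCCL comms so blocked calls return
  virtual void setException(std::exception_ptr e) = 0;  // stored error, rethrown from wait()
};

// Watchdog sketch: periodically scan outstanding work, abort anything that
// has timed out, and record an exception so the hang becomes a visible error.
void watchdogLoop(std::vector<std::shared_ptr<TimedWork>>& outstanding,
                  std::atomic<bool>& stop) {
  while (!stop.load()) {
    for (auto& work : outstanding) {
      if (work->timedOut()) {
        work->abortCollective();
        work->setException(std::make_exception_ptr(
            std::runtime_error("NCCL collective operation timed out")));
      }
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}
```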

Differential Revision: D21943520


dr-ci bot commented Jul 7, 2020

💊 CI failures summary and remediations

As of commit 1f38e08 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests"

Sep 08 22:34:24 what(): NCCL error: unhandled system error, NCCL version 2.7.6
Sep 08 22:34:17   test_scatter_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (2.128s) 
Sep 08 22:34:17   test_scatter_checks (__main__.ProcessGroupGlooTest) ... ok (0.128s) 
Sep 08 22:34:18   test_scatter_stress (__main__.ProcessGroupGlooTest) ... ok (0.825s) 
Sep 08 22:34:18   test_scatter_stress_cuda (__main__.ProcessGroupGlooTest) ... skip (0.001s) 
Sep 08 22:34:18   test_send_recv_all_to_all (__main__.ProcessGroupGlooTest) ... ok (0.125s) 
Sep 08 22:34:18   test_sparse_allreduce_basics (__main__.ProcessGroupGlooTest) ... ok (0.624s) 
Sep 08 22:34:21   test_sparse_allreduce_basics_cuda (__main__.ProcessGroupGlooTest) ... ok (2.429s) 
Sep 08 22:34:21   test_sparse_allreduce_checks (__main__.ProcessGroupGlooTest) ... ok (0.127s) 
Sep 08 22:34:24   test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok (2.638s) 
Sep 08 22:34:24   test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... terminate called after throwing an instance of 'std::runtime_error' 
Sep 08 22:34:24   what():  NCCL error: unhandled system error, NCCL version 2.7.6 
Sep 08 22:34:24 Traceback (most recent call last): 
Sep 08 22:34:24   File "test/run_test.py", line 735, in <module> 
Sep 08 22:34:24     main() 
Sep 08 22:34:24   File "test/run_test.py", line 718, in main 
Sep 08 22:34:24     raise RuntimeError(err_message) 
Sep 08 22:34:24 RuntimeError: distributed/test_c10d failed! Received signal: SIGIOT 
Sep 08 22:34:25 + cleanup 
Sep 08 22:34:25 + retcode=1 
Sep 08 22:34:25 + set +x 
Sep 08 22:34:25 =================== sccache compilation log =================== 

ci.pytorch.org: 1 failed



osalpekar added 2 commits July 6, 2020 17:40
@facebook-github-bot
Contributor

This pull request has been merged in afbf2f1.

facebook-github-bot deleted the gh/osalpekar/56/head branch September 13, 2020 14:17
loadbxh pushed a commit to loadbxh/Torch that referenced this pull request Sep 23, 2020
Pull Request resolved: pytorch/pytorch#41053

ghstack-source-id: 111311018

Differential Revision: [D21943520](https://our.internmc.facebook.com/intern/diff/D21943520/)