Conversation

@osalpekar (Member) commented on Jul 7, 2020

Stack from ghstack:

**This Commit:**
ProcessGroupNCCL destructor now blocks until all WorkNCCL objects have either been aborted or completed and removed from the work vector.

**This Stack:**
The purpose of this stack is to fix the hanging behavior observed when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang while waiting on an unresponsive worker. This stack detects such hangs and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.

Differential Revision: [D22054298](https://our.internmc.facebook.com/intern/diff/D22054298/)

[ghstack-poisoned]
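
To make the mechanism above concrete, here is a minimal, hedged sketch of the destructor-side wait in C++. The member names `workList_`, `workListMutex_`, and `workListCV_` come from the review snippets quoted later in this conversation; the `WorkNCCL` stub and the surrounding class are illustrative assumptions, not the actual ProcessGroupNCCL implementation.

```cpp
// Sketch only: the destructor drops already-completed work, then blocks on a
// condition variable until the cleanup thread has drained workList_.
#include <condition_variable>
#include <list>
#include <memory>
#include <mutex>

struct WorkNCCL {
  // Placeholder for the real WorkNCCL; completed_ stands in for the actual
  // completion/abort bookkeeping.
  bool isCompleted() const { return completed_; }
  bool completed_ = false;
};

class ProcessGroupSketch {
 public:
  ~ProcessGroupSketch() {
    std::unique_lock<std::mutex> lock(workListMutex_);
    // Erase work that has already finished instead of waiting for the
    // cleanup thread to be scheduled again.
    for (auto it = workList_.begin(); it != workList_.end(); /* no increment */) {
      if ((*it)->isCompleted()) {
        it = workList_.erase(it);
      } else {
        ++it;
      }
    }
    // Block until every remaining WorkNCCL object has been aborted or
    // completed and removed from the list by the cleanup thread.
    workListCV_.wait(lock, [&] { return workList_.empty(); });
  }

 private:
  std::list<std::unique_ptr<WorkNCCL>> workList_;
  std::mutex workListMutex_;
  std::condition_variable workListCV_;
};
```

The point of the pattern is that the cleanup thread, not the destructor, removes in-flight work; the destructor only drops entries that are already complete and then waits for the list to drain.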
@dr-ci bot commented on Jul 7, 2020

💊 CI failures summary and remediations

As of commit f19707c (more details on the Dr. CI page):

❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1), step "Run tests":

Sep 08 22:37:06 RuntimeError: Process 0 terminated or timed out after 100.08557462692261 seconds
Sep 08 22:37:06 ====================================================================== 
Sep 08 22:37:06 ERROR [100.108s]: test_failure_recovery (__main__.DistributedDataParallelTest) 
Sep 08 22:37:06 ---------------------------------------------------------------------- 
Sep 08 22:37:06 Traceback (most recent call last): 
Sep 08 22:37:06   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 224, in wrapper 
Sep 08 22:37:06     self._join_processes(fn) 
Sep 08 22:37:06   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 337, in _join_processes 
Sep 08 22:37:06     self._check_return_codes(elapsed_time) 
Sep 08 22:37:06   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 375, in _check_return_codes 
Sep 08 22:37:06     raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time)) 
Sep 08 22:37:06 RuntimeError: Process 0 terminated or timed out after 100.08557462692261 seconds 
Sep 08 22:37:06  
Sep 08 22:37:06 ---------------------------------------------------------------------- 
Sep 08 22:37:06 Ran 120 tests in 326.287s 
Sep 08 22:37:06  
Sep 08 22:37:06 FAILED (errors=2, skipped=9) 
Sep 08 22:37:06  
Sep 08 22:37:06 Generating XML reports... 
Sep 08 22:37:06 Generated XML report: test-reports/python-unittest/TEST-CommTest-20200908223140.xml 
Sep 08 22:37:06 Generated XML report: test-reports/python-unittest/TEST-ComputeBucketAssignmentTest-20200908223140.xml 
Sep 08 22:37:06 Generated XML report: test-reports/python-unittest/TEST-DistributedDataParallelTest-20200908223140.xml 

ci.pytorch.org: 1 failed



osalpekar added a commit that referenced this pull request Jul 7, 2020
Pull Request resolved: #41054

We should block until all WorkNCCL objects have been either aborted or completed and removed from the work vector.
ghstack-source-id: 107224185

Differential Revision: [D22054298](https://our.internmc.facebook.com/intern/diff/D22054298/)
Comment on lines 644 to 648
```cpp
if (workList_.empty()) {
  // Notify the main thread if it is blocked in the shutdown sequence,
  // waiting for the work vector to become empty.
  lock.unlock();
  workVectorCV_.notify_one();
```
Contributor:

Do we really need to do this? Wouldn't this automatically abort when terminateProcessGroup_ is set to True? Or are we referring to some other thread here?

Member (Author):

This notifies the CV in the destructor that is waiting for the workList_ to become empty.

Comment on lines 451 to 462
```cpp
std::unique_lock<std::mutex> lock(workListMutex_);
// Clean up any remaining items in the workList_ instead of waiting for the
// workCleanup Thread to be scheduled again.
for (auto it = workList_.begin(); it != workList_.end();
     /* no increment*/) {
  auto& work = *it;
  if (work->isCompleted()) {
    it = workList_.erase(it);
  } else {
    ++it;
  }
}
```
Contributor:

Why do we need to perform this explicit cleanup? Once the destructor completes, wouldn't workList_ automatically be freed?

Member (Author):

Right after this code block, we block in the destructor until workList_ is empty (no unfinished collectives left). Ideally this cleanup would be done in the workCleanupThread itself, but there was one corner case causing an issue here; Hongyi and I are continuing to investigate, and I'll create an issue regarding this.

Contributor:

Do we need to block in the destructor until workList_ is empty? How does removing completed items from workList_ help in the shutdown here?

Member (Author):

If there are leftover WorkNCCL objects in workList_, that means there are incomplete collectives. So we block until workList_ becomes empty to ensure that all collectives have either completed or errored out before we destruct ProcessGroupNCCL. Ideally, the workCleanupThread would do all of the cleanup. However, when models contain a SyncBatchNorm layer, we found that this cleanup had to occur in the destructor. Hongyi (@jiayisuse) and I have investigated this quite a bit, and I've created a follow-up issue (#44403). We should be able to remove the explicit cleanup from the destructor and let the workCleanupThread handle it completely, and I'll continue to push on this as a Better Engineering task.

Comment on lines +651 to +656
```cpp
if (workList_.empty()) {
  // Notify the main thread if it is blocked in the shutdown sequence,
  // waiting for the work vector to become empty.
  lock.unlock();
  workListCV_.notify_one();
}
```
@jiayisuse (Contributor) commented on Sep 2, 2020:

This is needed

Comment on lines +470 to +472
```cpp
// Wait for workList_ to become empty before proceeding with shutdown.
workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); });
lock.unlock();
```
@jiayisuse (Contributor) commented on Sep 2, 2020:

Checked again, the above code just removes completed work. So I guess we let the cleanup thread remove the unfinished work?

Member (Author):

We're still blocking until workList_ is empty, so the workCleanupThread will keep looping and removing work objects as they complete.
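
For context, here is a hedged sketch of the cleanup-thread side of that handshake. The member names `workList_`, `workListMutex_`, `workListCV_`, and `terminateProcessGroup_` are taken from the snippets and comments in this conversation; the loop structure, shutdown condition, and polling interval are illustrative assumptions, not the actual workCleanupLoop.

```cpp
// Sketch only: the cleanup thread repeatedly erases completed work and, once
// the list is empty, wakes the destructor that may be waiting on workListCV_.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <list>
#include <memory>
#include <mutex>
#include <thread>

struct WorkNCCL {
  bool isCompleted() const { return completed_.load(); }
  std::atomic<bool> completed_{false};
};

struct CleanupSketch {
  std::list<std::unique_ptr<WorkNCCL>> workList_;
  std::mutex workListMutex_;
  std::condition_variable workListCV_;
  std::atomic<bool> terminateProcessGroup_{false};

  void workCleanupLoop() {
    while (true) {
      {
        std::unique_lock<std::mutex> lock(workListMutex_);
        // Remove completed work, mirroring the excerpt quoted above.
        for (auto it = workList_.begin(); it != workList_.end(); /* no increment */) {
          if ((*it)->isCompleted()) {
            it = workList_.erase(it);
          } else {
            ++it;
          }
        }
        if (workList_.empty()) {
          // Wake the destructor if it is blocked waiting for an empty list.
          lock.unlock();
          workListCV_.notify_one();
          if (terminateProcessGroup_.load()) {
            return;  // Shutdown requested and nothing left to clean up.
          }
        }
      }
      // Illustrative polling interval; real scheduling differs.
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
};
```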

@facebook-github-bot (Contributor) commented:

This pull request has been merged in 211ece7.

facebook-github-bot deleted the gh/osalpekar/57/head branch on September 13, 2020 at 14:17.
loadbxh pushed a commit to loadbxh/Torch that referenced this pull request Sep 23, 2020
Pull Request resolved: pytorch/pytorch#41054

ghstack-source-id: 111311019

Differential Revision: [D22054298](https://our.internmc.facebook.com/intern/diff/D22054298/)