Conversation

@jeffdaily
Collaborator

There are two mutexes within CUDACachingAllocator that cause a deadlock. One of the mutexes was added in order to work around the issue of NCCL interacting poorly with cudaFree. See

As of NCCL version 2 and its new group start/end APIs, the protection surrounding cudaFree() is no longer needed. The PyTorch code was updated to use the NCCL2 group start/end API, but the corresponding cuda_free_mutex and its getter getFreeMutex() were not revised. This PR removes the use of getFreeMutex() when NCCL2 is used by moving the getFreeMutex() calls into AutoNcclGroup. That way, depending on the NCCL version in use, we either take the mutex or rely on the new group APIs.
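
For illustration, a rough sketch of what such an AutoNcclGroup guard could look like is below. This is not the exact code in this PR: the NCCL_MAJOR compile-time guard and the include paths are assumptions, while getFreeMutex() and C10D_NCCL_CHECK are the existing helpers referenced elsewhere in this thread. (The revised version of the PR, discussed further down, ends up keeping the free mutex in both branches and only scoping it via this guard.)

    // Sketch only: RAII guard that either opens an NCCL group (NCCL 2.x) or
    // takes the allocator's free mutex (NCCL 1.x), matching the description
    // above. The NCCL_MAJOR check is an assumed way to detect NCCL 2.
    #include <nccl.h>
    #include <c10/cuda/CUDACachingAllocator.h>
    #include <c10d/NCCLUtils.hpp>  // assumed location of C10D_NCCL_CHECK

    struct AutoNcclGroup {
      AutoNcclGroup() {
    #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
        C10D_NCCL_CHECK(ncclGroupStart());  // defer launches to ncclGroupEnd()
    #else
        (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
    #endif
      }
      ~AutoNcclGroup() {
    #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
        C10D_NCCL_CHECK(ncclGroupEnd());    // all queued kernels launch here
    #else
        (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
    #endif
      }
    };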

The race condition is as follows, thanks to @skeelyamd:

The deadlock occurs between hip_free_mutex (aka cuda_free_mutex in upstream PyTorch) (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L165) and mutex (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L162).

hip_free_mutex is exported from THCCachingAllocator in getFreeMutex (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L660) and is acquired in ProcessGroupNCCL::collective (https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.cpp#L397), which then calls back into THCCachingAllocator via c10::cuda::CUDACachingAllocator::recordStream (https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.cpp#L416 to https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L655 to https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L379). At this point it acquires mutex (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L384).

This requires hip_free_mutex to be locked before mutex.

However, in free_blocks (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L505) THCCachingAllocator locks hip_free_mutex. free_blocks is called from emptyCache (https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp#L328), which locks mutex.

That requires mutex to be locked before hip_free_mutex.

Given these conflicting lock orders, emptyCache and ProcessGroupNCCL::collective must not execute concurrently, yet that is exactly what is occurring, and the process deadlocks.

free_blocks is also called by malloc (via cuda_malloc_retry -> free_cached_blocks -> free_blocks), which also locks mutex first, so malloc must likewise not execute concurrently with ProcessGroupNCCL::collective.
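
To make the ordering conflict concrete, here is a minimal, self-contained model of it (hypothetical names, not the PyTorch code). Thread A mimics ProcessGroupNCCL::collective (free mutex, then allocator mutex), thread B mimics emptyCache (allocator mutex, then free mutex); run long enough, the two threads deadlock:

    #include <mutex>
    #include <thread>

    std::mutex free_mutex;                 // stands in for hip_free_mutex / cuda_free_mutex
    std::recursive_mutex allocator_mutex;  // stands in for the allocator's mutex

    void collective_like() {   // free_mutex first, then allocator_mutex (as in recordStream)
      std::lock_guard<std::mutex> f(free_mutex);
      std::lock_guard<std::recursive_mutex> a(allocator_mutex);
    }

    void empty_cache_like() {  // allocator_mutex first, then free_mutex (as in free_blocks)
      std::lock_guard<std::recursive_mutex> a(allocator_mutex);
      std::lock_guard<std::mutex> f(free_mutex);
    }

    int main() {
      std::thread t1([] { for (int i = 0; i < 1000000; ++i) collective_like(); });
      std::thread t2([] { for (int i = 0; i < 1000000; ++i) empty_cache_like(); });
      t1.join();
      t2.join();
      return 0;
    }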

@pytorchbot added the module: cuda, oncall: distributed, module: nccl, and module: pybind labels on Jun 25, 2019
@soumith added the triaged label on Jun 25, 2019
@pietern
Contributor

pietern commented Jun 25, 2019

Does calling ncclGroupEnd imply that kernels are launched on multiple devices atomically? If so, we can remove the mutex. If not, we have to keep it around. If cudaFree races with the NCCL internals that are responsible for launching kernels on multiple devices, we reintroduce the potential deadlock.

@jeffdaily
Collaborator Author

I believe this PR is related to #14870.

NCCL 2.x documentation concerning Group Calls: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/groups.html

From the NCCL sources, in ncclGroupEnd:

  /* Collectives are done in three steps :
   * 1. Barrier Check In. Only the last call may call cudaLaunchKernel[cooperative]
   * 2. Barrier Wait. No CUDA call is permitted
   * 3. Enqueue Events. CUDA event wait/enqueue.
   * This is needed because step 2 cannot call any CUDA primitive, otherwise if
   * cudaFree happens between 1 and 3, it could block that CUDA call and
   * prevent some ranks from launching their network threads, which would
   * prevent the NCCL call from completing, blocking the cudaFree call.
   */

In light of the above, it's not clear to me whether NCCL 2 fixed the race with cudaFree or not. However, the NCCL 2 code guarantees that only the last CPU thread to arrive at the collective for a given communicator will perform the launching for all devices. For sufficiently new (CUDART_VERSION >= 9000) and capable hardware, it will use cudaLaunchCooperativeKernelMultiDevice.
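
For reference, the NCCL 2 group pattern that defers launches to ncclGroupEnd looks roughly like this (a sketch only; it assumes the communicators, streams, and per-device buffers were created elsewhere, and error checking is omitted):

    #include <cuda_runtime.h>
    #include <nccl.h>

    // One collective per device inside a single group; NCCL queues the work
    // and launches all kernels when ncclGroupEnd() is called.
    void grouped_allreduce(int nDevices, const int* devs, ncclComm_t* comms,
                           cudaStream_t* streams, float** sendbufs,
                           float** recvbufs, size_t count) {
      ncclGroupStart();
      for (int i = 0; i < nDevices; ++i) {
        cudaSetDevice(devs[i]);
        ncclAllReduce(sendbufs[i], recvbufs[i], count, ncclFloat, ncclSum,
                      comms[i], streams[i]);
      }
      ncclGroupEnd();  // kernel launches for every device happen here
    }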

An alternative to this PR would be to remove the cuda_free_mutex and replace it with the recursive_mutex in the same file. This would also prevent deadlock, but without removing the cudaFree race protection. But it is not clear what the performance impact would be relative to just removing the cuda_free_mutex when NCCL v2 is in use.
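
As a sketch of why that alternative would also avoid the deadlock (hypothetical names again, not the actual allocator code): with a single recursive_mutex there is no second lock to order against, and the re-entrant path from the collective back into recordStream simply re-locks the mutex the same thread already holds.

    #include <mutex>

    std::recursive_mutex allocator_mutex;  // the single lock in this alternative

    void record_stream_like() {
      // Re-entrant lock: fine, because the same thread already owns it.
      std::lock_guard<std::recursive_mutex> g(allocator_mutex);
    }

    void collective_like() {
      // Takes the cudaFree protection and the allocator lock at once.
      std::lock_guard<std::recursive_mutex> g(allocator_mutex);
      record_stream_like();
    }

    void empty_cache_like() {
      // Only one lock exists, so there is no ordering to violate.
      std::lock_guard<std::recursive_mutex> g(allocator_mutex);
    }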

@pietern
Contributor

pietern commented Jul 1, 2019

Thanks for the detailed analysis, @jeffdaily. I went and looked at the NCCL code and did some more investigation myself today and found the following:

  • Cooperative multi device launch is used if the CUDA version is >= 9.0 and the devices have cudaDevAttrCooperativeMultiDeviceLaunch set. I wrote a quick program to check this attribute on a few classes of GPUs and found this is supported since the Pascal generation (a rough reconstruction of such a check appears after this list). For older GPUs, it will loop over the participating devices and launch the kernels sequentially. We still support Kepler and Maxwell generation GPUs, so we have to deal with the non-atomic multi device launch case.
  • The cudaFree race is not solved. If during sequential launch (either in non-atomic group mode or parallel launch mode) any process runs a cudaFree on another thread, it risks deadlock. If I recall correctly, calling cudaFree triggers device synchronization and waits for all kernels to exit. If any of these are NCCL kernels, and not all of them have been launched, they will hang.
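
A rough reconstruction of such an attribute check (not the original program mentioned above) could look like this:

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      int n = 0;
      cudaGetDeviceCount(&n);
      for (int dev = 0; dev < n; ++dev) {
        int supported = 0;
        // Query whether cooperative multi-device launch is available on this GPU.
        cudaDeviceGetAttribute(&supported,
                               cudaDevAttrCooperativeMultiDeviceLaunch, dev);
        std::printf("device %d: cooperative multi-device launch %s\n", dev,
                    supported ? "supported" : "not supported");
      }
      return 0;
    }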

Regarding the lock ordering, this is indeed an issue. I like the idea of using the AutoNcclGroup here. While we can't remove the free lock, we can narrow its scope to avoid locking the main allocator lock while holding the free lock. Then we'd have two loops: one to launch the NCCL kernels, and one to record streams. They don't depend on each other, so we can sequence them one after the other to fix this. This doesn't depend on AutoNcclGroup, of course, and can be done independently.

@jeffdaily force-pushed the remove_cuda_free_mutex branch from 8fd31c7 to 42cf602 on July 2, 2019 at 20:44
@jeffdaily
Collaborator Author

Hi @pietern. I have rebased the PR and revised it based on your comments. The free lock remains, and its lifetime is managed by the AutoNcclGroup. There are now two loops: one to record the streams and one to launch the NCCL kernels. The stream recording happens first and does so without holding the free lock. The NCCL kernel launches then happen within an AutoNcclGroup scope block.
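
To illustrate the resulting lock behavior, here is a structural model of the reworked collective using plain standard-library stand-ins (hypothetical types and names, not the ProcessGroupNCCL code):

    #include <mutex>
    #include <vector>

    std::recursive_mutex allocator_mutex;  // stands in for the allocator's mutex
    std::mutex free_mutex;                 // stands in for cuda_free_mutex

    struct AutoNcclGroupModel {            // models AutoNcclGroup's locking
      AutoNcclGroupModel()  { free_mutex.lock();   /* plus ncclGroupStart() with NCCL 2 */ }
      ~AutoNcclGroupModel() { /* plus ncclGroupEnd() with NCCL 2 */ free_mutex.unlock(); }
    };

    void record_stream_model(int /*tensor*/) { // models recordStream()
      std::lock_guard<std::recursive_mutex> g(allocator_mutex);
    }

    void collective_model(const std::vector<int>& tensors) {
      // Loop 1: record streams. free_mutex is NOT held, so taking the
      // allocator mutex here cannot invert the lock order.
      for (int t : tensors) record_stream_model(t);

      // Loop 2: launch the NCCL kernels. free_mutex is held only for this
      // scope, and the allocator mutex is never taken inside it.
      AutoNcclGroupModel nccl_group_guard;
      for (int t : tensors) { (void)t; /* launch kernel for tensor t */ }
    }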

@pietern
Contributor

pietern commented Jul 17, 2019

Thanks for updating the PR @jeffdaily. This looks good to me. @mrshenli can you take a look as well?

@mrshenli
Contributor

@pietern I am looking now

Contributor

@mrshenli left a comment

Thanks! LGTM too!


namespace {

struct AutoNcclGroup {
Contributor

Question, is this the same as the one in nccl.h except the error checking? Is it possible to directly use that one here?

Collaborator Author

Yes, the code here is identical to the one in nccl.h except for the error checking macro. This 'new' AutoNcclGroup is local to the ProcessGroupNCCL.cpp file. I chose to repeat the code here instead of trying to add a header dependency on torch/csrc/cuda, since the code there depends on c10 -- a possible circular dependency? I'm not yet completely familiar with how the PyTorch directories and libraries are structured. Let me know how I should proceed. Thanks.

Contributor

I see, thanks! Then, let's keep it as is.

Contributor

@facebook-github-bot left a comment

@pietern is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pietern
Contributor

pietern commented Jul 18, 2019

Thanks @jeffdaily and @mrshenli!

@facebook-github-bot
Contributor

@pietern merged this pull request in 29347cc.

@yf225
Contributor

yf225 commented Jul 18, 2019

This PR breaks master CUDA builds with:

Jul 18 14:59:18 In file included from /var/lib/jenkins/workspace/torch/lib/c10d/../c10d/ProcessGroupNCCL.hpp:6:0,
Jul 18 14:59:18 from /var/lib/jenkins/workspace/torch/lib/c10d/ProcessGroupNCCL.cpp:1:
Jul 18 14:59:18 /var/lib/jenkins/workspace/torch/lib/c10d/ProcessGroupNCCL.cpp: In destructor 'c10d::{anonymous}::AutoNcclGroup::~AutoNcclGroup()':
Jul 18 14:59:18 /var/lib/jenkins/workspace/torch/lib/c10d/../c10d/NCCLUtils.hpp:14:35: error: throw will always call terminate() [-Werror=terminate]
Jul 18 14:59:18 throw std::runtime_error(err);
Jul 18 14:59:18 ^
Jul 18 14:59:18 /var/lib/jenkins/workspace/torch/lib/c10d/ProcessGroupNCCL.cpp:27:5: note: in expansion of macro 'C10D_NCCL_CHECK'
Jul 18 14:59:18 C10D_NCCL_CHECK(ncclGroupEnd());
Jul 18 14:59:18 ^
Jul 18 14:59:18 /var/lib/jenkins/workspace/torch/lib/c10d/../c10d/NCCLUtils.hpp:14:35: note: in C++11 destructors default to noexcept
Jul 18 14:59:18 throw std::runtime_error(err);
Jul 18 14:59:18 ^
Jul 18 14:59:18 /var/lib/jenkins/workspace/torch/lib/c10d/ProcessGroupNCCL.cpp:27:5: note: in expansion of macro 'C10D_NCCL_CHECK'
Jul 18 14:59:18 C10D_NCCL_CHECK(ncclGroupEnd());
Jul 18 14:59:18 ^
Jul 18 14:59:18 cc1plus: all warnings being treated as errors
Jul 18 14:59:18 caffe2/lib_c10d/CMakeFiles/c10d.dir/build.make:206: recipe for target 'caffe2/lib_c10d/CMakeFiles/c10d.dir/ProcessGroupNCCL.cpp.o' failed
Jul 18 14:59:18 make[2]: *** [caffe2/lib_c10d/CMakeFiles/c10d.dir/ProcessGroupNCCL.cpp.o] Error 1
Jul 18 14:59:18 CMakeFiles/Makefile2:11636: recipe for target 'caffe2/lib_c10d/CMakeFiles/c10d.dir/all' failed
Jul 18 14:59:18 make[1]: *** [caffe2/lib_c10d/CMakeFiles/c10d.dir/all] Error 2

I am reverting it.
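
For context on the diagnostic: in C++11 a destructor is noexcept by default, so a macro that throws from ~AutoNcclGroup terminates instead of propagating, which is what -Werror=terminate flags. A minimal reproduction (unrelated to the c10d code) is:

    #include <stdexcept>

    struct Guard {
      ~Guard() {                        // implicitly noexcept in C++11
        throw std::runtime_error("x");  // g++: "throw will always call terminate()"
      }
    };

    int main() { Guard g; return 0; }   // terminates when ~Guard throws

Typical ways out are to avoid throwing from the destructor, e.g. by logging or swallowing the error; the exact approach taken by the follow-up PR is not shown here.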

@jeffdaily
Collaborator Author

How did CI not catch this earlier? This AutoNcclGroup code is basically copied from torch/csrc/cuda/nccl.h; how does it work there?

@jeffdaily
Collaborator Author

@yf225 can we fix this with a new PR rather than revert the entire PR?

@yf225
Contributor

yf225 commented Jul 18, 2019

@jeffdaily If the fix PR takes longer, it would mean that all other PRs based on current master have their CUDA builds broken, which I don't think is a good idea.

Since most of this PR has been reviewed, opening a new PR with the fix added shouldn't take too long to review, and we can make sure CI is green before merging it.

@jeffdaily
Collaborator Author

@yf225 ok to revert. I will start the new PR now, with the fix added.

@mrshenli
Contributor

@ezyang do you know why our CI didn't catch this error?

@yf225
Contributor

yf225 commented Jul 18, 2019

@mrshenli I suspect that it's because the last CI run was too old (16 days ago).

@jeffdaily
Collaborator Author

@ezyang @mrshenli I have submitted PR #23040 to hopefully address the new CI failure.

facebook-github-bot pushed a commit that referenced this pull request Jul 22, 2019
Summary:
Revision of #22173 to address CI failure after merging.
Pull Request resolved: #23040

Differential Revision: D16366872

Pulled By: mrshenli

fbshipit-source-id: 747b6ecf2dc195c25f82b8f732ae9ff52cd3a394