[NCCL] Add experimental Nonblocking NCCL Fault Tolerance/Checking #95715
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95715
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 Failures as of commit 8b5c091: NEW FAILURES. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 1ccaf67 to 81ce78a
): ...
@staticmethod
def _group_start() -> None: ...
@staticmethod
Removing static method as _group_end() might need to check the communicator map of the ProcessGroup to properly wait on collectives if nonblocking is used.
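A minimal sketch of the idea (not the actual ProcessGroupNCCL implementation; the `comms_in_flight` parameter is hypothetical): with nonblocking communicators, `ncclGroupEnd` can return `ncclInProgress`, so a non-static `_group_end` needs per-instance state to know which communicators to poll.

```cpp
#include <nccl.h>
#include <stdexcept>
#include <string>
#include <vector>

void groupEndAndWait(const std::vector<ncclComm_t>& comms_in_flight) {
  ncclResult_t result = ncclGroupEnd();
  if (result == ncclInProgress) {
    // The grouped operations are still being enqueued asynchronously; poll
    // each communicator until it leaves the in-progress state.
    for (ncclComm_t comm : comms_in_flight) {
      ncclResult_t state = ncclInProgress;
      do {
        ncclCommGetAsyncError(comm, &state);
      } while (state == ncclInProgress);
    }
  } else if (result != ncclSuccess) {
    throw std::runtime_error(
        std::string("NCCL error: ") + ncclGetErrorString(result));
  }
}
```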
torch/csrc/cuda/nccl.cpp
Outdated
if (!comm_nonblocking) {
  NCCL_CHECK(ncclCommCount(comm, &numranks));
} else {
  NCCL_CHECK_NONBLOCKING(ncclCommCount(comm, &numranks), _comm);
Might be unnecessary to also do a non-blocking check for ncclCommCount (unsure if there exists documentation on exactly which API calls might leave a communicator in an in-progress state).
Agree it is unnecessary. It is the user's responsibility to make sure comm is ready before accessing any attribute of it. (If comm is not ready, this call would actually error out rather than returning ncclInProgress.)
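As a hedged illustration of that point (the helper names below are hypothetical, not PyTorch or NCCL APIs): wait for the communicator to become ready once, and then attribute queries such as `ncclCommCount` only need the plain blocking-style check.

```cpp
#include <nccl.h>
#include <stdexcept>
#include <string>

// Poll until a nonblocking communicator has finished initializing (or failed).
void waitUntilCommReady(ncclComm_t comm) {
  ncclResult_t state = ncclInProgress;
  do {
    ncclCommGetAsyncError(comm, &state);
  } while (state == ncclInProgress);
}

int commCount(ncclComm_t comm) {
  waitUntilCommReady(comm);  // comm is now either ready or in a hard error state
  int numranks = 0;
  ncclResult_t res = ncclCommCount(comm, &numranks);
  if (res != ncclSuccess) {
    // A not-yet-ready or failed communicator errors out here; it does not
    // return ncclInProgress, so no nonblocking check is needed.
    throw std::runtime_error(
        std::string("NCCL error: ") + ncclGetErrorString(res));
  }
  return numranks;
}
```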
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 14)
#define NCCL_HAS_COMM_NONBLOCKING
Not sure why this needs to be redefined here in order to work when a definition already exists in ProcessGroupNCCL.hpp.
Because NCCLUtils.hpp does not include ProcessGroupNCCL.hpp, and the two are not in the same compilation unit?
nit: a reminder for me to clean this up.
@pytorchmergebot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased
Force-pushed 84c038a to 82f26eb
#ifdef NCCL_HAS_COMM_NONBLOCKING
  ncclResult_t result = to_nccl_result(status);
  while (result == ncclInProgress) {
    ncclCommGetAsyncError(to_nccl_comm(comm), &result);
I wonder if to_nccl_comm(comm) is needed here.
Here is the definition of to_nccl_comm:
ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}
It seems to me comm is already an ncclComm_t (the one defined by NCCL).
Side note: we should remove the duplicated ncclComm_t definition in torch::cuda::nccl; it is making things complicated. That is out of scope for this PR, so we can do it later.
torch/csrc/cuda/nccl.cpp
Outdated
static inline void NCCL_CHECK_NONBLOCKING(
    ncclResult_t result,
    ncclComm_t comm) {
It seems to me this should be the base case. I could be wrong though :)
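For reference, a hedged sketch of what such a single-communicator base case could look like, in the spirit of the `NCCL_CHECK_NONBLOCKING` helper quoted above (illustrative only, not the merged code): if the call reported `ncclInProgress`, poll the communicator's async state until it settles, then treat any remaining non-success result as a hard error.

```cpp
#include <nccl.h>
#include <stdexcept>
#include <string>

static inline void checkNonblocking(ncclResult_t result, ncclComm_t comm) {
  if (result == ncclInProgress) {
    // Poll the communicator until its asynchronous work is no longer pending.
    do {
      ncclCommGetAsyncError(comm, &result);
    } while (result == ncclInProgress);
  }
  if (result != ncclSuccess) {
    throw std::runtime_error(
        std::string("NCCL error: ") + ncclGetErrorString(result));
  }
}
```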
for (const auto i : c10::irange(comms.size())) {
  do {
    ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
  } while (result == ncclInProgress);
I wonder which one should be the inner loop.
Could one comm be hanging while another has already errored out, in which case we would miss catching the error here?
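One possible answer to that question, sketched below (illustrative only, not the merged code): poll every communicator once per outer iteration, so an error on any communicator is observed even while another one is still, possibly indefinitely, in progress.

```cpp
#include <nccl.h>
#include <vector>

// Poll all communicators round-robin: return the first hard error seen, and
// keep looping only while at least one communicator is still in progress.
ncclResult_t pollAllComms(const std::vector<ncclComm_t>& comms) {
  bool anyInProgress = true;
  while (anyInProgress) {
    anyInProgress = false;
    for (ncclComm_t comm : comms) {
      ncclResult_t state = ncclSuccess;
      ncclCommGetAsyncError(comm, &state);
      if (state == ncclInProgress) {
        anyInProgress = true;   // keep polling this communicator
      } else if (state != ncclSuccess) {
        return state;           // surface the error even if other comms hang
      }
    }
  }
  return ncclSuccess;
}
```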
@pytorchmergebot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 3 jobs have failed; the first few of them are: periodic / cuda11.7-py3.10-gcc7-sm86-periodic-dynamo-benchmarks / test (aot_eager_huggingface, 1, 1, linux.g5.4xlarge.nvidia.gpu), periodic / cuda11.7-py3.10-gcc7-sm86-periodic-dynamo-benchmarks / test (dynamic_aot_eager_torchbench, 1, 1, linux.g5.4xlarge.nvidia.gpu), inductor / cuda11.8-py3.10-gcc7-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.

@pytorchmergebot -f "assume inductor failures unrelated"

❌ 🤖 pytorchbot command failed: Try

@pytorchmergebot merge -f "assume inductor failures unrelated"

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
#95715 added the functionality to abort `ncclCommInitRankConfig` by specifying `blocking=0` to enable non-blocking behavior. However, calling `pg._abort()` didn't recover from a stuck `ncclCommInitRankConfig`, since the `_abort` method only looked through the `devNCCLCommMap_` map and aborted those communicators. Because `ncclCommInitRankConfig` was stuck, the communicator itself was never added to the map, and the host thread was stuck on this line: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1171. As a result, `_abort` was a no-op.

To resolve this issue, I added the communicators to `inProgressCommMap_` as soon as they were created and then removed them once added to `devNCCLCommMap_`. I also added a unit test that was failing without the changes to ProcessGroupNCCL.cpp.

Pull Request resolved: #103264
Approved by: https://github.com/kwen2501
…3925)

Pull Request resolved: #103925
Approved by: https://github.com/osalpekar
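A condensed sketch of the fix described above (the real ProcessGroupNCCL code wraps communicators in NCCLComm objects and takes locks; the member and key types here are simplified for illustration):

```cpp
#include <nccl.h>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct ProcessGroupSketch {
  std::unordered_map<std::string, std::vector<ncclComm_t>> inProgressCommMap_;
  std::unordered_map<std::string, std::vector<ncclComm_t>> devNCCLCommMap_;

  void registerComm(const std::string& key, ncclComm_t comm) {
    // Track the communicator immediately after ncclCommInitRankConfig, so an
    // abort can reach it even if initialization never completes.
    inProgressCommMap_[key].push_back(comm);
  }

  void markReady(const std::string& key) {
    // Once initialization finishes, move the communicators into the regular
    // map and stop tracking them as in-progress.
    devNCCLCommMap_[key] = std::move(inProgressCommMap_[key]);
    inProgressCommMap_.erase(key);
  }

  void abortAll() {
    // _abort must cover both maps; aborting in-progress communicators is what
    // unblocks a host thread stuck inside a nonblocking init.
    for (auto& kv : inProgressCommMap_)
      for (ncclComm_t c : kv.second) ncclCommAbort(c);
    for (auto& kv : devNCCLCommMap_)
      for (ncclComm_t c : kv.second) ncclCommAbort(c);
  }
};
```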
Support for nonblocking NCCL communicators / fault tolerance / checking, which NCCL 2.14 added as an experimental feature.
Enabled via the environment variable:
CC @ptrblck
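For context, a minimal sketch of the NCCL 2.14+ nonblocking-init mechanism this PR builds on (illustrative only; the function name below is hypothetical, and the PR itself wires this up through ProcessGroupNCCL):

```cpp
#include <nccl.h>

ncclComm_t initNonblockingComm(int nranks, ncclUniqueId id, int rank) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  config.blocking = 0;  // request a nonblocking communicator (NCCL >= 2.14)
  ncclComm_t comm = nullptr;
  ncclResult_t res = ncclCommInitRankConfig(&comm, nranks, id, rank, &config);
  // With blocking == 0, res may be ncclInProgress: initialization continues
  // asynchronously, and the host can poll ncclCommGetAsyncError(comm, ...)
  // or call ncclCommAbort(comm) instead of being stuck inside the init call.
  (void)res;
  return comm;
}
```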