Description
Motivation
DistributedDataParallel (DDP) training on GPUs using the NCCL process group routinely hangs, which is an unpleasant experience for users of PyTorch Distributed. In various situations (desynchronization, high GPU utilization, etc.), one rank in the process group may be slower to call the collective than the remaining ranks. Often, the slow rank is blocked on a previous CUDA operation and, due to high GPU utilization, cannot proceed to the collective call. Meanwhile, all the remaining ranks block waiting for the stuck rank, which causes the entire training job to hang indefinitely. Training can stay in this state for hours, or until it is manually killed by the user.
Alternatives
Ideally, we would like to provide a mechanism to detect and recover from hangs without any performance overhead. One feature already exists to detect hangs: NCCL_BLOCKING_WAIT. With blocking wait enabled, the main thread blocks when the wait function is called on the associated WorkNCCL object, and this wait function polls at a fixed interval to check whether the collective has timed out; if it has, an exception is thrown from the main thread. However, because it blocks the main thread, this approach may incur up to a 60% regression in training performance.
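For reference, here is a minimal sketch of how blocking wait is used today. This is illustrative rather than part of the proposal: the timeout value is arbitrary, and rank/world-size setup is assumed to come from the launcher's environment variables.

import os
from datetime import timedelta

import torch
import torch.distributed as dist

# NCCL_BLOCKING_WAIT must be set before the NCCL process group is created.
os.environ["NCCL_BLOCKING_WAIT"] = "1"

dist.init_process_group(backend="nccl", timeout=timedelta(seconds=30))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

tensor = torch.ones(1, device="cuda")
work = dist.all_reduce(tensor, async_op=True)
# With blocking wait, this call blocks the main thread and polls at a fixed
# interval; it raises an exception if the collective times out.
work.wait()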
Pitch
An analysis of the blocking wait functionality suggests that error handling and timeout checking must happen asynchronously. The existing ncclCommWatchdogThread polls for NCCL errors at a fixed interval and aborts the associated NCCL communicators so that future NCCL functions do not operate on corrupted data. We can additionally make the ncclCommWatchdogThread check for timed-out collectives and, if necessary, set an appropriate exception on the WorkNCCL objects associated with those collectives.
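A simplified, Python-style sketch of the proposed watchdog extension is shown below; the actual logic lives in C++ inside ProcessGroupNCCL, and the names (works, op_timeout, abort_communicators, and the methods on work) are illustrative placeholders.

import time
from datetime import datetime, timedelta

def nccl_comm_watchdog(works, op_timeout: timedelta, abort_communicators):
    while True:
        for work in list(works):
            # Existing behavior: surface asynchronous NCCL errors and abort
            # the associated communicators.
            if work.check_for_nccl_errors():
                work.set_exception(RuntimeError("NCCL error detected"))
                abort_communicators(work)
            # Proposed addition: flag collectives that have exceeded the timeout.
            elif datetime.now() - work.start_time > op_timeout:
                work.set_exception(RuntimeError("NCCL collective timed out"))
                abort_communicators(work)
        time.sleep(1)  # poll at a fixed interval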
We cannot surface exceptions set on the WorkNCCL objects by blocking the main thread or by using some other trigger (such as the next collective call), since this may incur a large performance overhead and may not work for all workloads. As a result, we introduce a new helper thread, the workCleanupThread. Every time a collective is called, we add its WorkNCCL object to a list. The workCleanupThread then iterates through this list of ongoing collectives. We check whether a collective has completed successfully using cudaEventQuery and, if so, remove its object from the list. For WorkNCCL objects that have an exception set (due to errors or timeouts detected by the watchdog), we rethrow the exception. Since this exception is thrown from a helper thread, the training process will crash.
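Again as a simplified, Python-style sketch (the real implementation is in C++; work_list, lock, and the methods on work are illustrative placeholders):

import time

def work_cleanup_loop(work_list, lock):
    while True:
        with lock:
            for work in list(work_list):
                exc = work.exception()
                if exc is not None:
                    # Rethrowing from this helper thread crashes the training
                    # process, surfacing the error or timeout to the user.
                    raise exc
                if work.is_completed():  # e.g., backed by cudaEventQuery
                    work_list.remove(work)
        time.sleep(1)  # poll at a fixed interval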
Due to the asynchronous nature of detecting and surfacing errors, this feature has little to no performance overhead for DDP
training on even the most complex models.
Usage
To enable this feature, set the environment variable NCCL_ASYNC_ERROR_HANDLING to 1. The timeout after which stuck
collectives are aborted can be configured when initializing the process group:
from datetime import timedelta

import torch.distributed as dist

dist.init_process_group(
    ...,
    backend="nccl",
    timeout=timedelta(seconds=30),  # Set your desired timeout here. The default is 30 minutes.
)
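If you prefer to enable the feature from inside the training script rather than the shell, a hedged sketch follows (the variable must be set before the NCCL process group is created, and rank/world-size setup is assumed to come from the launcher):

import os
from datetime import timedelta

import torch.distributed as dist

# Must be set before the NCCL process group is constructed.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

dist.init_process_group(backend="nccl", timeout=timedelta(seconds=30))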
Using this feature by itself (while training with DDP and NCCL) allows users to abort stuck collectives and thereby save compute time that would otherwise be wasted by the hang. However, using this feature together with torchelastic allows training to continue even after a hang: this feature crashes the training process after detecting a stuck collective, and torchelastic sees the SIGABRT from the training process and restarts training from the last checkpoint. Together, they provide a comprehensive method for detecting and recovering from hangs with little performance overhead.
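The restart path relies on the usual checkpoint-resume pattern; a minimal sketch, with the file name, save frequency, and epoch loop purely illustrative:

import os

import torch

CKPT_PATH = "checkpoint.pt"  # illustrative path

def train(model, optimizer, num_epochs):
    start_epoch = 0
    # On restart (e.g., after torchelastic relaunches the process), resume
    # from the most recent checkpoint if one exists.
    if os.path.exists(CKPT_PATH):
        ckpt = torch.load(CKPT_PATH, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_epoch = ckpt["epoch"] + 1

    for epoch in range(start_epoch, num_epochs):
        ...  # run one epoch of DDP training
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
            },
            CKPT_PATH,
        )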
Lastly, this feature is separate from NCCL_BLOCKING_WAIT, so only one of these two environment variables should be set during training.
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski