🚀 Feature
with @pritamdamania87 @mrshenli @zhaojuanmao
This RFC is to summarize the current proposal for supporting uneven inputs across different DDP processes. Related discussion in #33148. An example pain point from a user is on the PyTorch forums.
Problem
torch.nn.parallel.DistributedDataParallel is a commonly used tool for distributed data-parallel training, but it currently requires the user to provide an equal number of inputs to each participating DDP process (or to handle the resulting error appropriately). When processes have an unequal number of inputs during training, DDP fails. Utilities such as DataLoader and DistributedSampler make it easier to satisfy this assumption by distributing the dataset evenly, but they cannot cover every case, and many users have workloads in which uneven inputs must be supported.
The following script gives a simple example of the error:
import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch.nn as nn

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    model = nn.Linear(1, 1, bias=False).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    # Create uneven inputs, rank 1 will get one more input than rank 0. This will cause a hang.
    inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
    for _ in range(5):
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
    torch.cuda.synchronize(device=rank)

if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    mp.spawn(worker, nprocs=2, args=())
With the NCCL backend (the recommended choice when training with multiple GPUs), this results in a hang: process 1 waits for communication (an allreduce) from process 0, but process 0 has already exited its training loop. With the Gloo backend, it instead results in a "Connection reset by peer" error.
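For context only (this is not part of the proposal), one manual workaround users apply today is to agree on the smallest per-rank batch count up front and drop the surplus. The sketch below assumes the process group is already initialized and that inputs is the per-rank list of batches; it illustrates the data loss this approach forces on the larger ranks.

import torch
import torch.distributed as dist

def truncate_to_min_batches(inputs, device):
    # Agree on the smallest number of batches across ranks and stop there,
    # silently dropping the surplus batches on ranks that have more data.
    n = torch.tensor([len(inputs)], device=device)
    dist.all_reduce(n, op=dist.ReduceOp.MIN)
    return inputs[: int(n.item())]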
Proposal
This was proposed by @pritamdamania87 and is influenced by the approach taken by Horovod to resolve a similar problem (horovod/horovod#832).
1. Provide a context manager such as with torch.nn.parallel.distributed.join(). In __enter__, we set a flag indicating that we will run the process below for managing uneven inputs.

2. The context manager's __exit__ indicates that the process has depleted its input and is ready to join. When a trainer calls this:
a. Schedule an allreduce with torch.tensor(0). This allreduce will match the allreduce scheduled by non-joined processes (explained below in point 3).
b. If the result of the above is zero, all processes have depleted their inputs and we can move to step (d).
c. Otherwise, schedule an allreduce for all buckets in the backwards pass, with all gradients zeroed out (so that joined ranks don't affect the gradients of the processes still training). This will match the allreduces done in the backwards pass by currently active trainers. Go back to step (a).
d. If (a) returns all zeros, all ranks have exhausted their inputs and we can move on to cleanup. We also need to keep track of the process with the latest model parameters and broadcast them to all ranks, preserving DDP's guarantee that parameters are identical across processes. We can do this via a simple version counter: allgather the counters and have the process with the maximum counter broadcast its parameters to the rest. Ties can be broken arbitrarily.

3. If a trainer has not called __exit__, then:
a. Before scheduling allreduces for the backwards pass, schedule an allreduce with torch.tensor(1). This allreduce matches the one scheduled in (2a). We can schedule this allreduce in the forward pass, but for performance reasons we should not await it there; it should be awaited at the end of the backwards pass.
b. Schedule allreduce ops for all buckets as usual in the backwards pass. Processes that have depleted their inputs will match these allreduces as a result of step (2c); since they contribute zeroed gradients, they do not affect gradient averaging.
c. Since the effective world size can now be smaller (initial_world_size - currently_joined_processes), divide by this value instead of the static world_size to keep gradient averaging correct. The value returned in (3a), interpreted as an int, gives the number of currently active processes. (A standalone sketch of this protocol, written against plain collectives, follows this list.)
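As referenced above, here is a minimal sketch of steps 2 and 3 written against the public torch.distributed collectives. It is illustrative only: the real change would live inside DDP's reducer, bucketing is simplified to one "bucket" per parameter, and helper names such as _grad_buckets, active_iteration, and joined_loop are hypothetical.

import torch
import torch.distributed as dist

def _grad_buckets(model):
    # Stand-in for DDP's gradient buckets: one "bucket" per parameter here.
    return [p.grad if p.grad is not None else torch.zeros_like(p)
            for p in model.parameters()]

def active_iteration(model, inp, device, iters_done):
    # Step 3a: announce that this rank is still active. The allreduce is
    # scheduled before the backward pass but only awaited afterwards.
    num_active = torch.ones(1, device=device)
    work = dist.all_reduce(num_active, async_op=True)

    loss = model(inp).sum()
    loss.backward()

    # Steps 3b/3c: allreduce each bucket and divide by the effective world
    # size (initial_world_size - joined ranks) returned by the allreduce above.
    work.wait()
    for bucket in _grad_buckets(model):
        dist.all_reduce(bucket)
        bucket.div_(num_active.item())
    return iters_done + 1

def joined_loop(model, device, iters_done):
    # Steps 2a-2c: a rank that has run out of inputs keeps matching the
    # collectives issued by ranks that are still training.
    while True:
        num_active = torch.zeros(1, device=device)
        dist.all_reduce(num_active)                  # matches step 3a
        if num_active.item() == 0:                   # steps 2b/2d: all ranks done
            break
        for bucket in _grad_buckets(model):          # step 2c: shadow allreduces
            dist.all_reduce(torch.zeros_like(bucket))

    # Step 2d: agree on which rank processed the most inputs (its iteration
    # count acts as the version counter) and broadcast its parameters so all
    # replicas stay identical.
    version = torch.tensor([float(iters_done)], device=device)
    versions = [torch.zeros_like(version) for _ in range(dist.get_world_size())]
    dist.all_gather(versions, version)
    src = max(range(len(versions)), key=lambda r: versions[r].item())
    with torch.no_grad():
        for p in model.parameters():
            dist.broadcast(p.data, src=src)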
Code sample
model = nn.Linear(1, 1, bias=False).to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
# Create uneven inputs
inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
for _ in range(epochs):
    with torch.nn.parallel.distributed.join():
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
Alternatives considered
We considered the alternative of having all trainers raise a StopIteration once we detect (via the mechanism above) that at least one trainer has depleted its input. However, the user would then have to catch this StopIteration, and all processes would stop training, whereas the proposed method allows training to continue with a smaller effective world size. If we later see a need for users to actually stop training early in these situations, we can provide appropriate options, but we would like to keep usage of this API as simple as reasonably possible.
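For comparison, a rough sketch of what the rejected alternative would have required from the user, under the hypothetical behavior that DDP raises StopIteration out of the forward pass once any rank runs out of data:

# Hypothetical user-side handling under the rejected alternative.
try:
    for inp in inputs:
        loss = model(inp).sum()   # would raise StopIteration once any rank is depleted
        loss.backward()
except StopIteration:
    pass  # training stops on all ranks, even those with inputs remaining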
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar