Commit 08cf7b5
Join-based API to support DDP uneven inputs
Pull Request resolved: #42577

Closes #38174. Implements a join-based API to support training with the DDP module in the scenario where different processes have different numbers of inputs. The implementation follows the description in #38174. Details are available in the RFC, but as a summary, we make the following changes:

#### Approach
1) Add a context manager `torch.nn.parallel.distributed.join`.
2) In the forward pass, we schedule a "present" allreduce where non-joined processes contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all processes are joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and attempts to "shadow" the collective communication calls made in the model's forward and backward passes. For example, we schedule the same allreduces in the same order as the backward pass, but with zeros.
4) We adjust the allreduce division logic to divide by the effective world size (the number of non-joined processes) rather than the absolute world size to maintain correctness.
5) At the end of training, the last joined process is selected to be the "authoritative" model copy.

We also make some miscellaneous changes, such as adding a `rank` argument to `_distributed_broadcast_coalesced` and exposing some getters/setters on `Reducer`, to support the above changes.

#### How is it tested?
We have tests covering the following models/scenarios:
- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (ResNet). We verify this with a helper function `will_sync_module_buffers` and ensure this is true for ResNet (due to batchnorm).
- [x] Scenario where a rank calls join() without iterating at all, and thus without rebuilding buckets (which requires collective communication)
- [x] Model with unused parameters (with `find_unused_parameters=True`)
- [x] Scenarios where different processes run different numbers of iterations
- [x] Consistent tie-breaking when multiple ranks are the last ones to join
- [x] Division by the effective world size (the number of unjoined processes)

#### Limitations
1) This is only implemented for MPSD, not SPMD. Per a discussion with @mrshenli, we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass. This is because the `join` class only shadows the `broadcast` for buffers in the forward pass, the gradient allreduces in the backward pass, the unused-parameters reduction, and (optionally) the rebuild-buckets broadcasting in the backward pass. Supporting this will require additional design thought.
3) This has not been tested with the [DDP comm. hook](#39272), as that feature is still being finalized/in progress. We will add support for this in follow-up PRs.

ghstack-source-id: 109227369
Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22893859/)!
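#### Example usage (illustrative sketch, not part of this commit)

The snippet below is a minimal sketch of how the new `join` context manager is exercised, modeled on the tests added in this PR. The process-group setup, model, and per-rank batch counts are hypothetical placeholders; only the `torch.nn.parallel.distributed.join` import and the `with join(net):` pattern come from this change.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import join  # context manager added in this PR


def run(rank, world_size):
    # Hypothetical single-node setup; MASTER_ADDR/MASTER_PORT are assumed to be set.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    net = DistributedDataParallel(
        nn.Linear(10, 10, bias=False).cuda(rank), device_ids=[rank]
    )

    # Uneven inputs: each rank gets a different number of batches.
    inputs = [torch.ones(4, 10).cuda(rank) for _ in range(3 + 2 * rank)]

    with join(net):
        # A rank that exhausts its inputs "joins" early and shadows the
        # collective calls (buffer broadcasts, gradient allreduces) issued
        # by ranks that are still training, so nothing hangs.
        for inp in inputs:
            net(inp).sum().backward()
            net.zero_grad()

    dist.destroy_process_group()
```

While some ranks are joined, gradients on the remaining ranks are averaged over the effective (non-joined) world size, per change (4) above.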
1 parent 8850fd1 commit 08cf7b5

9 files changed: +705 additions, −51 deletions
test/distributed/test_c10d.py (19 additions, 8 deletions)

```diff
@@ -3544,7 +3544,7 @@ def op_timeout_sec(self):
     def world_size(self):
         return 2
 
-    def _test_broadcast_coalesced(self, process_group, device):
+    def _test_broadcast_coalesced(self, process_group, device, root_rank):
         half = torch.float16
 
         # No support for float16 for CPU tensors
@@ -3560,25 +3560,32 @@ def _test_broadcast_coalesced(self, process_group, device):
 
         # The tensors to pass to broadcast are idential to the target
         # only on the process that is the root of the broadcast.
-        if self.rank == 0:
+        if self.rank == root_rank:
             tensors = list(tensor.clone() for tensor in target)
         else:
-            tensors = list(torch.empty_like(tensor) for tensor in target)
+            tensors = list(torch.zeros_like(tensor) for tensor in target)
+
+        if self.rank != root_rank:
+            self.assertNotEqual(tensors, target)
 
         c10d._broadcast_coalesced(
             process_group,
             tensors,
-            buffer_size=256)
+            buffer_size=256,
+            authoritative_rank=root_rank)
 
-        self.assertEqual(tensors, target)
+        if self.rank != root_rank:
+            self.assertEqual(tensors, target)
 
     @requires_nccl()
     @skip_if_not_multigpu
     def test_broadcast_coalesced_nccl(self):
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
         device = torch.device('cuda:%d' % self.rank)
-        self._test_broadcast_coalesced(process_group, device)
+        ranks = list(range(self.world_size()))
+        for root_rank in ranks:
+            self._test_broadcast_coalesced(process_group, device, root_rank)
 
     @requires_gloo()
     @skip_if_not_multigpu
@@ -3588,7 +3595,9 @@ def test_broadcast_coalesced_gloo_cuda(self):
         options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
         device = torch.device('cuda:%d' % self.rank)
-        self._test_broadcast_coalesced(process_group, device)
+        ranks = list(range(self.world_size()))
+        for root_rank in ranks:
+            self._test_broadcast_coalesced(process_group, device, root_rank)
 
     @requires_gloo()
     def test_broadcast_coalesced_gloo_cpu(self):
@@ -3597,7 +3606,9 @@ def test_broadcast_coalesced_gloo_cpu(self):
         options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
         device = torch.device('cpu')
-        self._test_broadcast_coalesced(process_group, device)
+        ranks = list(range(self.world_size()))
+        for root_rank in ranks:
+            self._test_broadcast_coalesced(process_group, device, root_rank)
 
 
 if __name__ == '__main__':
```

test/distributed/test_distributed.py (250 additions, 0 deletions)

```diff
@@ -1,7 +1,9 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import copy
+from dataclasses import dataclass
 import errno
 import fcntl
+import itertools
 import math
 import os
 import random
@@ -13,6 +15,7 @@
 from datetime import timedelta
 from functools import reduce, wraps
 from io import StringIO
+from typing import Union
 
 import torch
 import torch.cuda
@@ -85,6 +88,14 @@ def forward(self, x):
         x = self.fc3(x)
         return F.softmax(x, dim=1)
 
+class Task(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.p = nn.Parameter(torch.ones(2, 2))
+
+    def forward(self, x):
+        return self.p + x
+
 
 class BatchNormNet(nn.Module):
 
@@ -2437,6 +2448,245 @@ def validate_global_samples(local_num_samples):
         # Ensure that each rank processes the same number of samples.
         validate_global_samples(local_num_samples)
 
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_sync_params_and_buffers(self):
+        # Test that after calling _sync_params_and_buffers, models across ranks
+        # are the same and are equal to the model on the input rank.
+        dim = 2
+        rank = self.rank
+        rank_to_broadcast = 1
+        # Seed to ensure that ranks are initialized with different initial models.
+        torch.manual_seed(rank)
+        model = nn.Linear(dim, dim, bias=False)
+        net = torch.nn.parallel.DistributedDataParallel(
+            model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+        )
+        new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
+        net.module = copy.deepcopy(new_model)
+        # Assert params are different
+        net_module_states = list(net.module.state_dict().values())
+        for t in net_module_states:
+            tensor_list = [
+                torch.zeros_like(t).to(self.rank) for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, t)
+            for i, tensor in enumerate(tensor_list):
+                if i == rank:
+                    self.assertEqual(t, tensor)
+                else:
+                    # tensor from another rank should be different.
+                    self.assertNotEqual(t, tensor)
+
+        net._sync_params_and_buffers(authoritative_rank=rank_to_broadcast)
+        # Now all model params should be the same.
+        net_module_states = list(net.module.state_dict().values())
+        for t in net_module_states:
+            tensor_list = [
+                torch.zeros_like(t).to(self.rank) for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, t)
+            for tensor in tensor_list:
+                self.assertEqual(tensor, t)
+        # Since the network params were broadcast from rank 1, validate that
+        # they are the same as new_model on rank 1.
+        if rank == rank_to_broadcast:
+            expected_states = new_model.state_dict().values()
+            for t, expected in zip(net_module_states, expected_states):
+                self.assertEqual(t, expected)
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_grad_div_uneven_inputs(self):
+        # Test that we scale by the effective world size when allreducing grads.
+        # For example if N processes start DDP training but 0 < K < N join, we
+        # should divide by N - K and not N.
+        dim = 5
+        batch = 1
+        grad_scale = 50
+        rank = self.rank
+        model = nn.Linear(dim, dim, bias=False)
+        inp = torch.ones(batch, dim).to(self.rank) * grad_scale
+        net = torch.nn.parallel.DistributedDataParallel(
+            model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+        )
+        n_iters = 1 if self.rank == 0 else 2
+        from torch.nn.parallel.distributed import join
+
+        with join(net) as ddp_join:
+            for i in range(n_iters):
+                loss = net(inp).sum()
+                loss.backward()
+                # The grad is always expected_grad, since we divide by the number
+                # of currently active processes and inactive processes contribute
+                # zero gradient. If we kept dividing by static initial world
+                # size as processes leave, the grad would be smaller.
+                expected_grad = torch.ones(dim, dim).to(self.rank) * grad_scale
+                param = list(net.parameters())[0]
+                self.assertEqual(expected_grad, param.grad)
+                # Avoid accumulating grads so that it's the same every iteration
+                net.zero_grad()
+                torch.cuda.synchronize(device=self.rank)
+
+    def _run_uneven_inputs_test(
+        self, test_case, num_iters, iteration_offset, find_unused_params
+    ):
+        from torch.nn.parallel.distributed import join
+
+        model = test_case.model
+        inp = test_case.inp
+        rank = self.rank
+        # Bucket_cap_mb is intentionally low to test allreduce scheduling when
+        # there are many buckets.
+        net = torch.nn.parallel.DistributedDataParallel(
+            model.cuda(rank),
+            device_ids=[rank],
+            bucket_cap_mb=1,
+            find_unused_parameters=find_unused_params,
+        )
+        if rank != 0:
+            num_iters += iteration_offset
+
+        with join(net) as ddp_join:
+            for _ in range(num_iters):
+                if isinstance(inp, tuple):
+                    loss = net(*inp).sum()
+                else:
+                    loss = net(inp).sum()
+                loss.backward()
+                self._model_step(net)
+                # Ensure completion of GPU kernels (including allreduce). If the
+                # join API is not properly implemented, then this should hang
+                # since the allreduce will hang.
+                torch.cuda.synchronize(device=rank)
+
+        # Ensure completion of all GPU kernels.
+        torch.cuda.synchronize(device=rank)
+        self.assertTrue(ddp_join.authoritative_rank)
+        # All ranks should have agreed on the same authoritative_rank!
+        final_rank_tensor = torch.tensor([ddp_join.authoritative_rank]).to(self.rank)
+        tensor_list = [
+            torch.zeros_like(final_rank_tensor).to(self.rank)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(tensor_list, final_rank_tensor)
+        max_rank = dist.get_world_size() - 1
+        self.assertSetEqual({max_rank}, set(tensor.item() for tensor in tensor_list))
+        # Ensure that all models are the same across ranks after all have joined.
+        net_module_states = list(net.module.state_dict().values())
+        for t in net_module_states:
+            tensor_list = [
+                torch.zeros_like(t).to(self.rank) for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, t)
+            for tensor in tensor_list:
+                self.assertEqual(t, tensor)
+        dist.barrier()
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_uneven_inputs(self):
+        @dataclass
+        class DDPUnevenTestInput:
+            name: str
+            model: nn.Module
+            inp: Union[torch.tensor, tuple]
+
+        dim = 1000
+        batch = 1
+        # Create a variety of models to run uneven input tests on.
+        large_model = nn.Sequential(
+            nn.Conv2d(1, 20, 5),
+            nn.ReLU(),
+            nn.Conv2d(20, 64, 5),
+            nn.ReLU(),
+            nn.Conv2d(64, 32, 5),
+            nn.ReLU(),
+            nn.Conv2d(32, 128, 5),
+            nn.ReLU(),
+            nn.Conv2d(128, 256, 5),
+            nn.ReLU(),
+        )
+        resnet_model = torchvision.models.resnet50()
+        small_model = nn.Linear(dim, dim, bias=False)
+
+        class UnusedParamModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+        class FindUnusedParamModule(nn.Module):
+            def __init__(self, unused_params_rank):
+                super(FindUnusedParamModule, self).__init__()
+                self.t0 = Task()
+                self.t1 = Task()
+                self.unused_params_rank = unused_params_rank
+
+            def task_parameters(self):
+                return (self.t0.p, self.t1.p)
+
+            def forward(self, x, rank):
+                return (
+                    self.t1(self.t0(x))
+                    if rank != self.unused_params_rank
+                    else self.t1(x)
+                )
+
+        unjoined_rank_with_unused_params_model = FindUnusedParamModule(1)
+        joined_rank_with_unused_params_model = FindUnusedParamModule(0)
+
+        rank = self.rank
+        models_to_test = [
+            DDPUnevenTestInput(
+                name="large_conv_model",
+                model=large_model,
+                inp=torch.ones(batch, batch, dim, dim).to(rank),
+            ),
+            DDPUnevenTestInput(
+                name="resnet_model",
+                model=resnet_model,
+                inp=torch.ones(1, 3, 1000, 1000),
+            ),
+            DDPUnevenTestInput(
+                name="small_model",
+                model=small_model,
+                inp=torch.ones(batch, dim).to(rank),
+            ),
+            # Unused parameter test where rank that does not join early has unused params
+            DDPUnevenTestInput(
+                name="unjoined_rank_with_unused_params_model",
+                model=unjoined_rank_with_unused_params_model,
+                inp=(torch.ones(batch, 2).to(rank), rank),
+            ),
+            # Unused parameter test where rank that does join early has unused params
+            DDPUnevenTestInput(
+                name="joined_rank_with_unused_params_model",
+                model=joined_rank_with_unused_params_model,
+                inp=(torch.ones(batch, 2).to(rank), rank),
+            ),
+        ]
+        # 0 iteration tests for when one process does not train model at all, so
+        # we must shadow the broadcast calls made when rebuilding buckets.
+        baseline_num_iters = [0, 5]
+        iteration_offsets = [1, 3, 10]
+        for (test_case, offset, baseline_num_iter) in itertools.product(
+            models_to_test, iteration_offsets, baseline_num_iters
+        ):
+            print(
+                f"Running test: {test_case.name} with n_iters {baseline_num_iter} and iteration offset {offset}"
+            )
+            self._run_uneven_inputs_test(
+                test_case,
+                baseline_num_iter,
+                offset,
+                find_unused_params=("unused_params_model" in test_case.name),
+            )
+
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
```
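As a side note on `test_ddp_grad_div_uneven_inputs` above: the expected gradient equals `grad_scale` only if the allreduced sum is divided by the effective world size. A standalone arithmetic sketch of that check (illustrative only, not part of the commit; the numbers mirror the test):

```python
# Each unjoined rank's local gradient of loss = (x @ W.T).sum() w.r.t. W is a
# matrix filled with grad_scale, because the input is filled with grad_scale.
# Joined ranks shadow the allreduce with zeros.
grad_scale = 50.0
world_size = 2   # static world size N
num_active = 1   # ranks still training after the others have joined

allreduced_sum = num_active * grad_scale + (world_size - num_active) * 0.0

grad_with_effective_divisor = allreduced_sum / num_active  # what this PR does
grad_with_static_divisor = allreduced_sum / world_size     # would shrink the gradient

assert grad_with_effective_divisor == grad_scale
assert grad_with_static_divisor == grad_scale / 2
```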

torch/csrc/distributed/c10d/comm.cpp (10 additions, 5 deletions)

```diff
@@ -14,10 +14,14 @@ class BroadcastWork {
  public:
   BroadcastWork(
       const std::shared_ptr<c10d::ProcessGroup>& process_group,
-      std::vector<at::Tensor> bucket_tensors)
+      std::vector<at::Tensor> bucket_tensors,
+      int rank = 0)
       : bucket_tensors_(std::move(bucket_tensors)),
-        flat_tensor_({torch::utils::flatten_dense_tensors(bucket_tensors_)}),
-        work_(process_group->broadcast(flat_tensor_)) {}
+        flat_tensor_({torch::utils::flatten_dense_tensors(bucket_tensors_)}) {
+    BroadcastOptions broadcastOptions;
+    broadcastOptions.rootRank = rank;
+    work_ = process_group->broadcast(flat_tensor_, broadcastOptions);
+  }
 
   void finish() {
     work_->wait();
@@ -51,7 +55,8 @@ class BroadcastWork {
 void broadcast_coalesced(
     std::shared_ptr<c10d::ProcessGroup> process_group,
     at::TensorList tensors,
-    size_t buffer_size) {
+    size_t buffer_size,
+    int rank) {
   // Coalesce tensors into buckets taking into account the maximum buffer size.
   // This routine is multi-device aware, so the tensors can be split across
   // multiple devices and can contain a mix of CPU and CUDA tensors.
@@ -71,7 +76,7 @@ void broadcast_coalesced(
       in_flight.pop_front();
     }
 
-    in_flight.emplace_back(process_group, c10::fmap(bucket, lookup));
+    in_flight.emplace_back(process_group, c10::fmap(bucket, lookup), rank);
   }
 
   while (!in_flight.empty()) {
```
torch/csrc/distributed/c10d/comm.h (1 addition, 1 deletion)

```diff
@@ -12,7 +12,7 @@ namespace c10d {
 void broadcast_coalesced(
     std::shared_ptr<c10d::ProcessGroup> process_group,
     at::TensorList tensors,
-    size_t buffer_size);
+    size_t buffer_size, int rank = 0);
 
 // This class passes bucket contents tensor (for multiple replicas) to
 // DDP communication hook.
```
