58 changes: 57 additions & 1 deletion test/test_c10d.py
@@ -2397,7 +2397,63 @@ def check_no_grads():

@skip_if_not_multigpu
@skip_if_not_nccl
def test_accumulate_gradients(self):
def test_accumulate_gradients_no_sync(self):
# This is the recommended way to implement gradient accumulation
int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
devices = [torch.device('cuda:' + str(i)) for i in int_devices]
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
global_batch_size = self.world_size
local_batch_size = len(devices)

model, ddp_model, input, target = \
self._prepare_single_device_module(
process_group, devices, devices, global_batch_size)

def step_model(model, input, target):
model.train()
output = model(input)
loss = F.mse_loss(output, target.to(output.device))
loss.backward()

# ensure gradient accumulation works under no_grad
with torch.no_grad():
with ddp_model.no_sync():
ddp_model.train()
ddp_model(input)

# check two model parameters over 2 iterations
for iteration in range(2):
# single cpu/gpu training
step_model(model, input, target)

ddp_input = input[self.rank * local_batch_size: (self.rank + 1) * local_batch_size]
ddp_target = target[self.rank * local_batch_size: (self.rank + 1) * local_batch_size]

if iteration % 2 == 0:
# accumulate grads locally when iteration == 0
with ddp_model.no_sync():
step_model(ddp_model, ddp_input, ddp_target)
else:
# sync grads when iteration == 1
step_model(ddp_model, ddp_input, ddp_target)

for i, j in zip(model.parameters(), ddp_model.parameters()):
if iteration % 2 == 0:
self.assertNotEqual(i.grad, j.grad)
else:
self.assertEqual(i.grad, j.grad)

# Shuffle the input so that DDP input is different
torch.manual_seed(1337 + iteration)
input = input[torch.randperm(global_batch_size)]
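
For reference, a minimal sketch of the usage pattern this test exercises: accumulate gradients locally for several iterations under no_sync(), then synchronize on the last one. The names ddp_model, optimizer, criterion, and batches are placeholders, not part of this PR:

accumulation_steps = 2
for step, (inputs, targets) in enumerate(batches):
    if (step + 1) % accumulation_steps != 0:
        # skip the allreduce; gradients accumulate locally in param.grad
        with ddp_model.no_sync():
            criterion(ddp_model(inputs), targets).backward()
    else:
        # this backward runs outside no_sync(), so the accumulated
        # gradients are allreduced across processes here
        criterion(ddp_model(inputs), targets).backward()
        optimizer.step()
        optimizer.zero_grad()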

@skip_if_not_multigpu
@skip_if_not_nccl
def test_accumulate_gradients_module(self):
# This is NOT the recommended way to implement gradient accumulation, but
# we would like to make sure DDP does not interfere with the underlying
# module.
int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
devices = [torch.device('cuda:' + str(i)) for i in int_devices]
store = c10d.FileStore(self.file.name, self.world_size)
34 changes: 32 additions & 2 deletions torch/nn/parallel/distributed.py
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import copy
import itertools

@@ -272,6 +273,8 @@ def __init__(self, module, device_ids=None,
self.module = module
self.broadcast_buffers = broadcast_buffers
self.find_unused_parameters = find_unused_parameters
self.require_backward_grad_sync = True
self.require_forward_param_sync = True
Review comment (Contributor): I'm not sure if DistributedDataParallel is picklable, but if it is, then you should add a __setstate__ that adds those two attributes, because otherwise people who load older checkpoints will get missing attribute errors.
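A minimal sketch of what this suggestion could look like (not part of this diff; it assumes nn.Module's default __setstate__ behavior and uses __dict__.setdefault to backfill the new flags when loading an older checkpoint):

def __setstate__(self, state):
    super(DistributedDataParallel, self).__setstate__(state)
    # Older checkpoints were pickled before these flags existed,
    # so fall back to the constructor defaults.
    self.__dict__.setdefault('require_backward_grad_sync', True)
    self.__dict__.setdefault('require_forward_param_sync', True)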


if check_reduction:
# This argument is no longer used since the reducer
@@ -377,8 +380,31 @@ def _check_default_group(self):
"init_process_group and have not passed "
"process_group argument to DDP constructor")

@contextmanager
def no_sync(self):
r"""
A context manager to disable gradient synchronization across DDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.

Example::

>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
>>> with ddp.no_sync():
...     for input in inputs:
...         ddp(input).backward()  # no synchronization, accumulate grads
>>> ddp(another_input).backward()  # synchronize grads
"""
old_require_backward_grad_sync = self.require_backward_grad_sync
self.require_backward_grad_sync = False
yield
Review comment (Contributor): This should be in a try ... finally block, because in case of an exception you will fail to restore the flag!
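A sketch of the suggested fix (not what the diff above does); the only change is wrapping the yield in try/finally so the flag is restored even if the body of the with block raises:

@contextmanager
def no_sync(self):
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        # restored even if the with-block body raises
        self.require_backward_grad_sync = old_require_backward_grad_sync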

self.require_backward_grad_sync = old_require_backward_grad_sync

def forward(self, *inputs, **kwargs):
self._sync_params()
if self.require_forward_param_sync:
self._sync_params()

if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
@@ -389,7 +415,8 @@ def forward(self, *inputs, **kwargs):
else:
output = self.module(*inputs, **kwargs)

if torch.is_grad_enabled():
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
@@ -399,6 +426,9 @@
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
self.require_forward_param_sync = False

return output
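
To illustrate how the two flags interact, a sketch assuming ddp is a constructed DistributedDataParallel instance and batch is an input whose output can be reduced to a scalar: a forward pass under no_sync() skips reducer preparation and clears require_forward_param_sync, and the first forward-backward pass after leaving the context prepares the reducer again so the accumulated gradients are allreduced:

with ddp.no_sync():
    ddp(batch).sum().backward()    # no allreduce; grads accumulate locally
    assert ddp.require_forward_param_sync is False

ddp(batch).sum().backward()        # reducer prepared; accumulated grads are allreduced
assert ddp.require_forward_param_sync is True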

def scatter(self, inputs, kwargs, device_ids):