-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Support accumulating DDP grads using a context manager #21736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 2 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| from contextlib import contextmanager | ||
| import copy | ||
| import itertools | ||
|
|
||
|
|
@@ -272,6 +273,8 @@ def __init__(self, module, device_ids=None, | |
| self.module = module | ||
| self.broadcast_buffers = broadcast_buffers | ||
| self.find_unused_parameters = find_unused_parameters | ||
| self.require_backward_grad_sync = True | ||
| self.require_forward_param_sync = True | ||
|
|
||
| if check_reduction: | ||
| # This argument is no longer used since the reducer | ||
|
|
@@ -377,8 +380,31 @@ def _check_default_group(self): | |
| "init_process_group and have not passed " | ||
| "process_group argument to DDP constructor") | ||
|
|
||
| @contextmanager | ||
| def no_sync(self): | ||
| r""" | ||
| A context manager to disable gradient synchronizations across DDP | ||
| processes. Within this context, gradients will be accumulated on module | ||
| variables, which will later be synchronized in the first | ||
| forward-backward pass exiting the context. | ||
|
|
||
| Example:: | ||
|
|
||
| >>> ddp = torch.nn.DistributedDataParallel(model, pg) | ||
| >>> with ddp.no_sync(): | ||
| ... for input in inputs: | ||
| ... ddp(input).backward() # no synchronization, accumulate grads | ||
| ... ddp(another_input).backward() # synchronize grads | ||
| """ | ||
| old_require_backward_grad_sync = self.require_backward_grad_sync | ||
| self.require_backward_grad_sync = False | ||
| yield | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be in a |
||
| self.require_backward_grad_sync = old_require_backward_grad_sync | ||
|
|
||
| def forward(self, *inputs, **kwargs): | ||
| self._sync_params() | ||
| if self.require_forward_param_sync: | ||
| self._sync_params() | ||
|
|
||
| if self.device_ids: | ||
| inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) | ||
| if len(self.device_ids) == 1: | ||
|
|
@@ -389,7 +415,8 @@ def forward(self, *inputs, **kwargs): | |
| else: | ||
| output = self.module(*inputs, **kwargs) | ||
|
|
||
| if torch.is_grad_enabled(): | ||
| if torch.is_grad_enabled() and self.require_backward_grad_sync: | ||
| self.require_forward_param_sync = True | ||
| # We'll return the output object verbatim since it is a freeform | ||
| # object. We need to find any tensors in this object, though, | ||
| # because we need to figure out which parameters were used during | ||
|
|
@@ -399,6 +426,9 @@ def forward(self, *inputs, **kwargs): | |
| self.reducer.prepare_for_backward(list(_find_tensors(output))) | ||
| else: | ||
| self.reducer.prepare_for_backward([]) | ||
| else: | ||
| self.require_forward_param_sync = False | ||
|
|
||
| return output | ||
|
|
||
| def scatter(self, inputs, kwargs, device_ids): | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if
DistributedDataParallelis picklable, but if it is, then you should add a__setstate__that adds those two attributes, because otherwise people who load older checkpoints will get missing attribute errors.