Skip to content

Commit f177579

Browse files
apaszkefacebook-github-bot
authored andcommitted
Fix minor issues with #21736 (#22074)
Summary: cc mrshenli Pull Request resolved: #22074 Differential Revision: D15965376 Pulled By: mrshenli fbshipit-source-id: 50ff96de6390817d8ea52c04322c6bee3d649b32
1 parent a458989 commit f177579

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torch/nn/parallel/distributed.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ def __setstate__(self, state):
390390
# If serializable, then the process group should be the default one
391391
self.process_group = _get_default_group()
392392
super(DistributedDataParallel, self).__setstate__(state)
393+
self.__dict__.setdefault('require_forward_param_sync', True)
394+
self.__dict__.setdefault('require_backward_grad_sync', True)
393395
self._ddp_init_helper()
394396

395397
def _check_default_group(self):
@@ -425,8 +427,10 @@ def no_sync(self):
425427
"""
426428
old_require_backward_grad_sync = self.require_backward_grad_sync
427429
self.require_backward_grad_sync = False
428-
yield
429-
self.require_backward_grad_sync = old_require_backward_grad_sync
430+
try:
431+
yield
432+
finally:
433+
self.require_backward_grad_sync = old_require_backward_grad_sync
430434

431435
def forward(self, *inputs, **kwargs):
432436
if self.require_forward_param_sync:

0 commit comments

Comments
 (0)