File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -390,6 +390,8 @@ def __setstate__(self, state):
390390 # If serializable, then the process group should be the default one
391391 self .process_group = _get_default_group ()
392392 super (DistributedDataParallel , self ).__setstate__ (state )
393+ self .__dict__ .setdefault ('require_forward_param_sync' , True )
394+ self .__dict__ .setdefault ('require_backward_grad_sync' , True )
393395 self ._ddp_init_helper ()
394396
395397 def _check_default_group (self ):
@@ -425,8 +427,10 @@ def no_sync(self):
425427 """
426428 old_require_backward_grad_sync = self .require_backward_grad_sync
427429 self .require_backward_grad_sync = False
428- yield
429- self .require_backward_grad_sync = old_require_backward_grad_sync
430+ try :
431+ yield
432+ finally :
433+ self .require_backward_grad_sync = old_require_backward_grad_sync
430434
431435 def forward (self , * inputs , ** kwargs ):
432436 if self .require_forward_param_sync :
You can’t perform that action at this time.
0 commit comments