2 files changed: +6 -4 lines changed

@@ -218,9 +218,10 @@ def _test_ddp_parity(
         # Run a few more iterations
         # TODO: We cannot run too many iterations, or else there is drift:
         # https://github.com/pytorch/pytorch/issues/89136
-        for _ in range(3):
-            ddp_optim.zero_grad(set_to_none=True)
-            fsdp_optim.zero_grad(set_to_none=True)
+        for i in range(3):
+            set_to_none = i % 2 == 0  # exercise both
+            ddp_optim.zero_grad(set_to_none=set_to_none)
+            fsdp_optim.zero_grad(set_to_none=set_to_none)
         inp = ddp_model.module.get_input(device)
         for model in (ddp_model, fsdp_model):
             out = model(*inp)
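For reference, `Optimizer.zero_grad(set_to_none=True)` frees the gradient tensors entirely, while `set_to_none=False` zeroes them in place, so the two settings exercise different code paths in the optimizer and in FSDP's gradient handling. The snippet below is a minimal standalone sketch, not part of this PR; the toy `nn.Linear` model and SGD optimizer are illustrative assumptions. It shows the observable difference that the alternating loop above exercises.

import torch

# Hypothetical toy model/optimizer, only to demonstrate the two zero_grad paths.
model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()
optim.zero_grad(set_to_none=False)  # gradients stay allocated, zeroed in place
assert all(p.grad is not None and bool((p.grad == 0).all()) for p in model.parameters())

model(torch.randn(2, 4)).sum().backward()
optim.zero_grad(set_to_none=True)   # gradients are freed and read back as `None`
assert all(p.grad is None for p in model.parameters())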
@@ -1161,7 +1161,8 @@ def clip_grad_norm_(
             self._streams["unshard"],
             self._streams["pre_unshard"],
         )
-        # Check for an early return if every FSDP instance uses `NO_SHARD`
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `nn.utils` one targeting local gradients
         all_no_shard = all(
             not handle.uses_sharded_strategy
             for handle in FullyShardedDataParallel._fsdp_handles(self)
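The updated comment describes falling back to the standard gradient-clipping utility when nothing is sharded. Below is a hedged, self-contained sketch of that idea only, not FSDP's actual implementation: the `uses_sharded_strategy` flag and the function name stand in for FSDP's per-handle check. With `NO_SHARD`, each rank already holds full local gradients, so `torch.nn.utils.clip_grad_norm_` applies directly.

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

def clip_local_grad_norm_(module: nn.Module, max_norm: float, uses_sharded_strategy: bool) -> torch.Tensor:
    # Sketch: when no handle uses a sharded strategy (all `NO_SHARD`), each
    # rank's local gradients are the full gradients, so the regular
    # `nn.utils` utility targeting local gradients is sufficient.
    if not uses_sharded_strategy:
        return clip_grad_norm_(module.parameters(), max_norm)
    # With a sharded strategy, the total norm must instead be computed across
    # ranks (e.g. an all-reduce over local shard norms) before scaling shards.
    raise NotImplementedError("sharded-gradient path not shown in this sketch")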