2 files changed: 41 additions, 2 deletions

@@ -209,6 +209,35 @@ def _test_ddp_parity(
         self.assertEqual(n1, n2)
         self.assertEqual(p1, p2)

+        if offload_params:
+            # TODO: Gradient computation on CPU and GPU differs slightly,
+            # causing drift unrelated to `clip_grad_norm_()`.
+            # https://github.com/pytorch/pytorch/issues/89133
+            return
+
+        # Run a few more iterations
+        # TODO: We cannot run too many iterations, or else there is drift:
+        # https://github.com/pytorch/pytorch/issues/89136
+        for i in range(3):
+            set_to_none = i % 2 == 0  # exercise both
+            ddp_optim.zero_grad(set_to_none=set_to_none)
+            fsdp_optim.zero_grad(set_to_none=set_to_none)
+            inp = ddp_model.module.get_input(device)
+            for model in (ddp_model, fsdp_model):
+                out = model(*inp)
+                out.sum().backward()
+            ddp_total_norm = torch.nn.utils.clip_grad_norm_(
+                ddp_model.parameters(),
+                max_norm=max_norm,
+                norm_type=norm_type,
+            )
+            fsdp_total_norm = fsdp_model.clip_grad_norm_(
+                max_norm=max_norm, norm_type=norm_type
+            )
+            self.assertEqual(ddp_total_norm, fsdp_total_norm)
+            ddp_optim.step()
+            fsdp_optim.step()
+

 instantiate_parametrized_tests(TestClipGradNorm)

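As a reference for what the parity check above asserts, here is a minimal, single-process sketch of the `clip_grad_norm_()` semantics. It uses a toy `torch.nn.Linear` model with no DDP or FSDP (an illustration only, not code from this PR): the utility returns the total gradient norm computed before clipping, then scales gradients in place when that norm exceeds `max_norm`.

# Single-process illustration only; the DDP/FSDP test above checks that both
# wrappers agree on these semantics across ranks.
import torch

model = torch.nn.Linear(8, 4)  # hypothetical toy model, not from the PR
model(torch.randn(2, 8)).sum().backward()

max_norm, norm_type = 1.0, 2.0
# Total norm over all parameter gradients, computed before clipping.
manual_norm = torch.norm(
    torch.stack([p.grad.detach().norm(norm_type) for p in model.parameters()]),
    norm_type,
)
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=max_norm, norm_type=norm_type
)
# `clip_grad_norm_()` returns the pre-clipping norm, so the two values agree;
# the gradients themselves are scaled in place by roughly
# max_norm / total_norm whenever total_norm exceeds max_norm.
assert torch.allclose(manual_norm, total_norm)

The test above iterates this pattern a few times with optimizer steps in between, so any drift between the DDP and FSDP gradient paths would surface as a mismatch in the reported total norms.
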
@@ -1161,10 +1161,20 @@ def clip_grad_norm_(
             self._streams["unshard"],
             self._streams["pre_unshard"],
         )
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `torch.nn.utils.clip_grad_norm_()` on the local gradients
+        all_no_shard = all(
+            not handle.uses_sharded_strategy
+            for handle in FullyShardedDataParallel._fsdp_handles(self)
+        )
+        if all_no_shard:
+            return torch.nn.utils.clip_grad_norm_(
+                self.parameters(), max_norm, norm_type
+            )
+        # Otherwise, some FSDP instance uses a sharded strategy, in which case
+        # sharded and non-sharded parameters must be handled separately
         max_norm = float(max_norm)
         norm_type = float(norm_type)
-        # Perform local gradient norm computation, where sharded and
-        # non-sharded parameters must be handled separately
         sharded_params = set()
         nonsharded_params = set()  # `NO_SHARD` or not FSDP-managed
         for handle in FullyShardedDataParallel._fsdp_handles(self):
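To make the new dispatch easier to follow, here is a hedged sketch under stated assumptions: `Handle` and `clip_grad_norm_sketch` are hypothetical stand-ins for FSDP internals (the real flat-parameter handles come from `FullyShardedDataParallel._fsdp_handles(self)`), and only the all-`NO_SHARD` fast path is filled in; the sharded branch that the rest of the method implements is elided.

# Hedged sketch of the control flow added above, not the actual FSDP code.
from dataclasses import dataclass
from typing import Iterable

import torch


@dataclass
class Handle:
    uses_sharded_strategy: bool  # False corresponds to `NO_SHARD`


def clip_grad_norm_sketch(
    module: torch.nn.Module,
    fsdp_handles: Iterable[Handle],
    max_norm: float,
    norm_type: float = 2.0,
) -> torch.Tensor:
    if all(not h.uses_sharded_strategy for h in fsdp_handles):
        # Every instance is `NO_SHARD`: each rank holds full local gradients,
        # so the standard utility already computes the correct total norm.
        return torch.nn.utils.clip_grad_norm_(
            module.parameters(), max_norm, norm_type
        )
    # Sharded case (elided): compute local norms for sharded and non-sharded
    # parameters separately, reduce them across ranks, then scale gradients.
    raise NotImplementedError("sharded-strategy path is elided in this sketch")


# Example: with only `NO_SHARD`-style handles, the sketch falls through to
# `torch.nn.utils.clip_grad_norm_` on the local parameters.
toy = torch.nn.Linear(4, 2)
toy(torch.randn(1, 4)).sum().backward()
print(clip_grad_norm_sketch(toy, [Handle(uses_sharded_strategy=False)], max_norm=1.0))

The fast path is worthwhile because under `NO_SHARD` the gradients are not partitioned across ranks, so no cross-rank reduction of per-rank norms is needed before clipping.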