Commit ee1d375

Andrew Gu authored and pytorchmergebot committed
[FSDP] Add fast path for NO_SHARD clip_grad_norm_() (#89137)
Pull Request resolved: #89137
Approved by: https://github.com/rohan-varma
1 parent e70f446 commit ee1d375
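
For context on the change, here is a minimal sketch of the idea behind the fast path: when every FSDP instance runs with `NO_SHARD`, each rank already holds full, unsharded gradients, so gradient clipping reduces to the ordinary local `torch.nn.utils.clip_grad_norm_()` call with no extra cross-rank norm combination. This sketch is illustrative only and not the patch itself; the free-standing function and its `handles` argument are hypothetical, while `uses_sharded_strategy` is the handle attribute the diff below actually checks.

import torch

def clip_grad_norm_fast_path_sketch(module, handles, max_norm, norm_type=2.0):
    """Illustrative sketch of the NO_SHARD fast path (not the FSDP source)."""
    # When no handle uses a sharded strategy, local gradients are already the
    # full gradients, so the standard utility computes the correct total norm.
    if all(not handle.uses_sharded_strategy for handle in handles):
        return torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm, norm_type)
    # Otherwise sharded and non-sharded gradients must be normed separately and
    # combined across ranks (the existing, slower path in the real method).
    raise NotImplementedError("sharded path omitted from this sketch")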

File tree

2 files changed: +41 −2 lines

test/distributed/fsdp/test_fsdp_clip_grad_norm.py

Lines changed: 29 additions & 0 deletions
@@ -209,6 +209,35 @@ def _test_ddp_parity(
             self.assertEqual(n1, n2)
             self.assertEqual(p1, p2)
 
+        if offload_params:
+            # TODO: Gradient computation on CPU and GPU differ slightly causing
+            # drift unrelated to `clip_grad_norm_()`.
+            # https://github.com/pytorch/pytorch/issues/89133
+            return
+
+        # Run a few more iterations
+        # TODO: We cannot run too many iterations, or else there is drift:
+        # https://github.com/pytorch/pytorch/issues/89136
+        for i in range(3):
+            set_to_none = i % 2 == 0  # exercise both
+            ddp_optim.zero_grad(set_to_none=set_to_none)
+            fsdp_optim.zero_grad(set_to_none=set_to_none)
+            inp = ddp_model.module.get_input(device)
+            for model in (ddp_model, fsdp_model):
+                out = model(*inp)
+                out.sum().backward()
+            ddp_total_norm = torch.nn.utils.clip_grad_norm_(
+                ddp_model.parameters(),
+                max_norm=max_norm,
+                norm_type=norm_type,
+            )
+            fsdp_total_norm = fsdp_model.clip_grad_norm_(
+                max_norm=max_norm, norm_type=norm_type
+            )
+            self.assertEqual(ddp_total_norm, fsdp_total_norm)
+            ddp_optim.step()
+            fsdp_optim.step()
+
 
 instantiate_parametrized_tests(TestClipGradNorm)
 
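As background for what the new iterations assert, here is a minimal single-process sketch (not part of the test; the toy `Linear` model is assumed for illustration): `torch.nn.utils.clip_grad_norm_()` returns the total gradient norm computed before clipping and rescales gradients in place, which is why identical DDP and FSDP replicas must report identical total norms and stay in lockstep after each optimizer step.

import torch

# Minimal single-process illustration of the clip_grad_norm_ contract that the
# DDP/FSDP parity loop above relies on (toy model, not from the test).
torch.manual_seed(0)
model = torch.nn.Linear(4, 4)
model(torch.randn(8, 4)).sum().backward()

# Returned value is the total norm *before* clipping; grads are scaled in place.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)
post_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(p.grad, 2) for p in model.parameters()]), 2
)
print(f"pre-clip total norm: {total_norm:.4f}, post-clip norm: {post_norm:.4f}")
assert post_norm <= max(total_norm, 1.0) + 1e-6  # grads were scaled to respect max_norm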
torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 12 additions & 2 deletions
@@ -1161,10 +1161,20 @@ def clip_grad_norm_(
             self._streams["unshard"],
             self._streams["pre_unshard"],
         )
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `nn.utils` one targeting local gradients
+        all_no_shard = all(
+            not handle.uses_sharded_strategy
+            for handle in FullyShardedDataParallel._fsdp_handles(self)
+        )
+        if all_no_shard:
+            return torch.nn.utils.clip_grad_norm_(
+                self.parameters(), max_norm, norm_type
+            )
+        # Otherwise, there exists some FSDP instance using a sharded strategy,
+        # where sharded and non-sharded parameters must be handled separately
         max_norm = float(max_norm)
         norm_type = float(norm_type)
-        # Perform local gradient norm computation, where sharded and
-        # non-sharded parameters must be handled separately
         sharded_params = set()
         nonsharded_params = set()  # `NO_SHARD` or not FSDP-managed
         for handle in FullyShardedDataParallel._fsdp_handles(self):
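
The relocated comment hints at why the slower path remains: under a sharded strategy each rank holds only a shard of every gradient, so local norms are partial and must be combined across ranks before clipping. Below is a simplified sketch of that combination for finite p-norms; it is not FSDP's actual implementation, the function name is hypothetical, and the infinity norm (which needs a MAX reduction instead of SUM) is deliberately ignored.

import torch
import torch.distributed as dist

def total_norm_of_sharded_grads_sketch(local_grad_shards, norm_type=2.0):
    """Simplified sketch: combine per-rank partial norms into a global norm."""
    # Each rank norms only the gradient shards it owns ...
    local_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, norm_type) for g in local_grad_shards]),
        norm_type,
    )
    # ... then the p-th powers are summed across ranks and the p-th root taken,
    # recovering the norm a single rank with the full gradients would compute.
    total = local_norm ** norm_type
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total ** (1.0 / norm_type)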
