
Commit ab9f9de

Author: Andrew Gu

Update on "[FSDP] Add fast path for NO_SHARD clip_grad_norm_()"

[ghstack-poisoned]

1 parent f83b576

2 files changed: +6 -4 lines

test/distributed/fsdp/test_fsdp_clip_grad_norm.py

Lines changed: 4 additions & 3 deletions
@@ -218,9 +218,10 @@ def _test_ddp_parity(
         # Run a few more iterations
         # TODO: We cannot run too many iterations, or else there is drift:
         # https://github.com/pytorch/pytorch/issues/89136
-        for _ in range(3):
-            ddp_optim.zero_grad(set_to_none=True)
-            fsdp_optim.zero_grad(set_to_none=True)
+        for i in range(3):
+            set_to_none = i % 2 == 0  # exercise both
+            ddp_optim.zero_grad(set_to_none=set_to_none)
+            fsdp_optim.zero_grad(set_to_none=set_to_none)
         inp = ddp_model.module.get_input(device)
         for model in (ddp_model, fsdp_model):
             out = model(*inp)
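
The test now alternates `set_to_none` so that both `zero_grad()` modes are exercised. As a quick standalone illustration (not part of the diff) of why the two modes differ, assuming a throwaway `torch.nn.Linear` model:

import torch

# set_to_none=True frees the gradient tensors; set_to_none=False zeroes them in place.
lin = torch.nn.Linear(2, 2)
optim = torch.optim.SGD(lin.parameters(), lr=0.1)

lin(torch.randn(4, 2)).sum().backward()
optim.zero_grad(set_to_none=True)
assert all(p.grad is None for p in lin.parameters())  # grads are gone

lin(torch.randn(4, 2)).sum().backward()
optim.zero_grad(set_to_none=False)
assert all(p.grad is not None and torch.all(p.grad == 0) for p in lin.parameters())  # zero-filled grads kept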

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 2 additions & 1 deletion
@@ -1161,7 +1161,8 @@ def clip_grad_norm_(
            self._streams["unshard"],
            self._streams["pre_unshard"],
        )
-        # Check for an early return if every FSDP instance uses `NO_SHARD`
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `nn.utils` one targeting local gradients
        all_no_shard = all(
            not handle.uses_sharded_strategy
            for handle in FullyShardedDataParallel._fsdp_handles(self)
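
The new comment describes the fast path this stack adds: when every FSDP instance uses `NO_SHARD`, each rank already holds full, unsharded local gradients, so clipping can delegate to the standard `torch.nn.utils.clip_grad_norm_`. A minimal sketch of that idea (not the actual FSDP implementation), reusing `uses_sharded_strategy` from the diff and assuming hypothetical `module`, `handles`, and `max_norm` inputs:

import torch
import torch.nn as nn

def clip_grad_norm_sketch(module: nn.Module, handles, max_norm: float) -> torch.Tensor:
    # Fast path: with NO_SHARD everywhere, gradients are already unsharded,
    # so the plain local clipping utility is correct on each rank.
    if all(not handle.uses_sharded_strategy for handle in handles):
        return torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)
    # Sharded strategies require computing a global gradient norm across ranks,
    # which is not reproduced in this sketch.
    raise NotImplementedError("sharded-gradient path omitted")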
