
Commit 996bb74

Andrew Gu authored and pytorchmergebot committed
[FSDP2] Added HSDP grad acc tests and some minor changes (#125479)
This adds HSDP to the existing gradient accumulation tests and includes some minor changes to simplify things a tiny bit.

Pull Request resolved: #125479
Approved by: https://github.com/wanchaol
ghstack dependencies: #125431
1 parent b96b1e8 commit 996bb74
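
For reference, here is a minimal sketch of the mesh selection this PR adds to the gradient accumulation test: a 1-D mesh always exercises plain FSDP, and a 2-D (replicate, shard) mesh exercises HSDP when 4 GPUs are available. This assumes a multi-GPU CUDA setup launched via torchrun; the toy nn.Linear module is illustrative only and is not part of the change.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

world_size = torch.cuda.device_count()
meshes = [init_device_mesh("cuda", (world_size,))]  # always test FSDP (1-D mesh)
if world_size == 4:  # test HSDP too if enough GPUs
    replicate_size, shard_size = 2, 2
    meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))

for mesh in meshes:
    model = nn.Linear(8, 8)  # illustrative toy module, not from the test
    # With a 2-D mesh, gradients are reduce-scattered over the shard dim and
    # all-reduced over the replicate dim (HSDP); a 1-D mesh is plain FSDP.
    fully_shard(model, mesh=mesh, reshard_after_forward=True)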

File tree

4 files changed: +53 −28 lines

.ci/pytorch/test.sh

Lines changed: 1 addition & 0 deletions
@@ -322,6 +322,7 @@ test_inductor_distributed() {
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_mlp
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_hsdp
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_transformer_checkpoint_resume
+  pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_gradient_accumulation
   pytest test/distributed/_composable/fsdp/test_fully_shard_frozen.py
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype

test/distributed/_composable/fsdp/test_fully_shard_training.py

Lines changed: 30 additions & 17 deletions
@@ -42,6 +42,7 @@
     FSDPTestMultiThread,
     MLP,
     patch_all_gather,
+    patch_all_reduce,
     patch_reduce_scatter,
     test_compiled_fsdp,
 )
@@ -650,16 +651,21 @@ def _test_train_shared_params(
 class TestFullyShardGradientAccumulation(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(4, torch.cuda.device_count())

     @skip_if_lt_x_gpu(2)
     def test_gradient_accumulation(self):
         """
         Tests gradient accumulation with/without gradient reduction and
         with/without resharding after backward.
         """
+        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
+        if self.world_size == 4:  # test HSDP too if enough GPUs
+            shard_size, replicate_size = 2, 2
+            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
         self.run_subtests(
             {
+                "mesh": meshes,
                 "reshard_after_forward": [True, False, 2],
                 # "all": disable reduce-scatter for all modules
                 # "root_only": disable reduce-scatter for root's linear only
@@ -673,6 +679,7 @@ def test_gradient_accumulation(self):

     def _test_gradient_accumulation(
         self,
+        mesh: DeviceMesh,
         reshard_after_forward: Union[bool, int],
         mode: str,
         reshard_after_backward: bool,
@@ -692,15 +699,13 @@ def _test_gradient_accumulation(
         global_batch_size = local_batch_size * self.world_size
         if mode == "some_mlps":
             num_mlps_to_disable_reduce_scatter = 2
-        model = nn.Sequential(
-            *(
-                [nn.Linear(lin_dim, lin_dim)]
-                + [MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
-            )
-        )
+        modules = [nn.Linear(lin_dim, lin_dim)]
+        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
+        model = nn.Sequential(*modules)
         ref_model = copy.deepcopy(model).cuda()
         fully_shard_fn = functools.partial(
             fully_shard,
+            mesh=mesh,
             reshard_after_forward=reshard_after_forward,
             offload_policy=offload_policy,
         )
@@ -710,10 +715,11 @@ def _test_gradient_accumulation(
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

+        # TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
         orig_all_gather = dist.all_gather_into_tensor
-        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
-        reduce_scatter_count = 0
+        orig_all_reduce = dist.all_reduce
+        all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0

         def all_gather_with_count(*args, **kwargs):
             nonlocal all_gather_count
@@ -725,11 +731,16 @@ def reduce_scatter_with_count(*args, **kwargs):
             reduce_scatter_count += 1
             return orig_reduce_scatter(*args, **kwargs)

+        def all_reduce_with_count(*args, **kwargs):
+            nonlocal all_reduce_count
+            all_reduce_count += 1
+            return orig_all_reduce(*args, **kwargs)
+
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
             with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
                 reduce_scatter_with_count
-            ):
+            ), patch_all_reduce(all_reduce_with_count):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
@@ -757,10 +768,7 @@ def reduce_scatter_with_count(*args, **kwargs):
                         * local_batch_size
                     ].detach()
                     losses: List[torch.Tensor] = []
-                    for _model, _optim, inp in (
-                        (ref_model, ref_optim, global_inp),
-                        (model, optim, local_inp),
-                    ):
+                    for _model, inp in ((ref_model, global_inp), (model, local_inp)):
                         losses.append(_model(inp).sum())
                         losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
@@ -779,7 +787,13 @@ def reduce_scatter_with_count(*args, **kwargs):
                 # Expect additional reduce-scatters for all MLPs
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
-            reduce_scatter_count = 0
+            # Exclude the loss all-reduce per microbatch in our training loop
+            all_reduce_count -= num_microbatches
+            if mesh.ndim == 2:
+                self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
+            else:
+                self.assertEqual(all_reduce_count, 0)
+            reduce_scatter_count = all_reduce_count = 0

             # Expect one all-gather per MLP plus one for the root's linear in
             # the first microbatch's forward
@@ -873,8 +887,7 @@ def _test_1f1b_microbatching(
             ref_losses.append(ref_model(inp).sum())
             ref_losses[-1].backward()
             for param in ref_model.parameters():
-                dist.all_reduce(param.grad)
-                param.grad.detach().div_(self.world_size)
+                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

         for loss, ref_loss in zip(losses, ref_losses):
             self.assertEqual(loss, ref_loss)

torch/distributed/_composable/fsdp/_fsdp_param_group.py

Lines changed: 10 additions & 11 deletions
@@ -167,7 +167,7 @@ def _init_mp_dtypes(self) -> None:
     def _init_grad_divide_factors(self):
         data_parallel_world_size = 1
         data_parallel_world_size *= self.mesh_info.shard_mesh_size
-        if isinstance(self.mesh_info, HSDPMeshInfo):
+        if self._is_hsdp:
             data_parallel_world_size *= self.mesh_info.replicate_mesh_size
         if self._reduce_dtype in (torch.float32, torch.bfloat16):
             # Use NCCL's AVG op to divide after reduction since it is more
@@ -348,7 +348,7 @@ def post_backward(self, *unused: Any):
             self.device,
             self._grad_divide_factors,
             self._all_reduce_process_group
-            if self._should_all_reduce_grads()
+            if self._is_hsdp and self.all_reduce_grads
             else None,
             self.comm_ctx.all_reduce_stream,
         )
@@ -481,6 +481,10 @@ def _use_post_forward_mesh(self) -> bool:
             and self.mesh_info != self.post_forward_mesh_info
         )

+    @property
+    def _is_hsdp(self) -> bool:
+        return isinstance(self.mesh_info, HSDPMeshInfo)
+
     @property
     def _all_gather_process_group(self) -> dist.ProcessGroup:
         mesh_info = (
@@ -493,18 +497,13 @@ def _all_gather_process_group(self) -> dist.ProcessGroup:

     @property
     def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
-        mesh_info = self.mesh_info
-        assert isinstance(mesh_info, FSDPMeshInfo)
-        return mesh_info.shard_process_group
+        assert isinstance(self.mesh_info, FSDPMeshInfo)
+        return self.mesh_info.shard_process_group

     @property
     def _all_reduce_process_group(self) -> dist.ProcessGroup:
-        mesh_info = self.mesh_info
-        assert isinstance(mesh_info, HSDPMeshInfo)
-        return mesh_info.replicate_process_group
-
-    def _should_all_reduce_grads(self) -> bool:
-        return isinstance(self.mesh_info, HSDPMeshInfo) and self.all_reduce_grads
+        assert isinstance(self.mesh_info, HSDPMeshInfo)
+        return self.mesh_info.replicate_process_group


 def _get_param_module_infos(

torch/testing/_internal/common_fsdp.py

Lines changed: 12 additions & 0 deletions
@@ -907,6 +907,18 @@ def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
         dist.reduce_scatter_tensor = orig_reduce_scatter


+@contextlib.contextmanager
+def patch_all_reduce(new_all_reduce: Callable):
+    orig_all_reduce = dist.all_reduce
+    dist.barrier()
+    dist.all_reduce = new_all_reduce
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_reduce = orig_all_reduce
+
+
 @no_type_check
 @contextlib.contextmanager
 def patch_unshard(new_unshard: Callable):
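
As a usage note, the new patch_all_reduce context manager mirrors the existing patch_all_gather and patch_reduce_scatter helpers. A rough sketch of how the test uses it to count HSDP gradient all-reduces follows; it assumes an already-initialized process group (patch_all_reduce calls dist.barrier() on entry and exit), and the counter wrapper here is a simplified standalone variant of the test's closure.

import torch.distributed as dist
from torch.testing._internal.common_fsdp import patch_all_reduce

all_reduce_count = 0
orig_all_reduce = dist.all_reduce  # capture the real collective before patching

def all_reduce_with_count(*args, **kwargs):
    # Count each c10d all-reduce, then delegate to the original collective.
    global all_reduce_count
    all_reduce_count += 1
    return orig_all_reduce(*args, **kwargs)

with patch_all_reduce(all_reduce_with_count):
    ...  # run forward/backward here; HSDP gradient all-reduces are counted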
