
Commit 0347872

Author: Andrew Gu

[FSDP2] Added HSDP grad acc tests and some minor changes

ghstack-source-id: 5332c88
Pull Request resolved: #125479

1 parent d3a7995

File tree: 4 files changed, +53 -28 lines

.ci/pytorch/test.sh

Lines changed: 1 addition & 0 deletions
@@ -322,6 +322,7 @@ test_inductor_distributed() {
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_mlp
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_hsdp
   pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_transformer_checkpoint_resume
+  pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_gradient_accumulation
   pytest test/distributed/_composable/fsdp/test_fully_shard_frozen.py
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype
   pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype

test/distributed/_composable/fsdp/test_fully_shard_training.py

Lines changed: 30 additions & 17 deletions
@@ -41,6 +41,7 @@
     FSDPTestMultiThread,
     MLP,
     patch_all_gather,
+    patch_all_reduce,
     patch_reduce_scatter,
     test_compiled_fsdp,
 )
@@ -649,16 +650,21 @@ def _test_train_shared_params(
 class TestFullyShardGradientAccumulation(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(4, torch.cuda.device_count())

     @skip_if_lt_x_gpu(2)
     def test_gradient_accumulation(self):
         """
         Tests gradient accumulation with/without gradient reduction and
         with/without resharding after backward.
         """
+        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
+        if self.world_size == 4:  # test HSDP too if enough GPUs
+            shard_size, replicate_size = 2, 2
+            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
         self.run_subtests(
             {
+                "mesh": meshes,
                 "reshard_after_forward": [True, False, 2],
                 # "all": disable reduce-scatter for all modules
                 # "root_only": disable reduce-scatter for root's linear only
@@ -672,6 +678,7 @@ def test_gradient_accumulation(self):

     def _test_gradient_accumulation(
         self,
+        mesh: DeviceMesh,
         reshard_after_forward: Union[bool, int],
         mode: str,
         reshard_after_backward: bool,
@@ -691,15 +698,13 @@ def _test_gradient_accumulation(
         global_batch_size = local_batch_size * self.world_size
         if mode == "some_mlps":
             num_mlps_to_disable_reduce_scatter = 2
-        model = nn.Sequential(
-            *(
-                [nn.Linear(lin_dim, lin_dim)]
-                + [MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
-            )
-        )
+        modules = [nn.Linear(lin_dim, lin_dim)]
+        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
+        model = nn.Sequential(*modules)
         ref_model = copy.deepcopy(model).cuda()
         fully_shard_fn = functools.partial(
             fully_shard,
+            mesh=mesh,
             reshard_after_forward=reshard_after_forward,
             offload_policy=offload_policy,
         )
@@ -709,10 +714,11 @@ def _test_gradient_accumulation(
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

+        # TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
         orig_all_gather = dist.all_gather_into_tensor
-        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
-        reduce_scatter_count = 0
+        orig_all_reduce = dist.all_reduce
+        all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0

         def all_gather_with_count(*args, **kwargs):
             nonlocal all_gather_count
@@ -724,11 +730,16 @@ def reduce_scatter_with_count(*args, **kwargs):
             reduce_scatter_count += 1
             return orig_reduce_scatter(*args, **kwargs)

+        def all_reduce_with_count(*args, **kwargs):
+            nonlocal all_reduce_count
+            all_reduce_count += 1
+            return orig_all_reduce(*args, **kwargs)
+
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
             with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
                 reduce_scatter_with_count
-            ):
+            ), patch_all_reduce(all_reduce_with_count):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
@@ -756,10 +767,7 @@ def reduce_scatter_with_count(*args, **kwargs):
                         * local_batch_size
                     ].detach()
                     losses: List[torch.Tensor] = []
-                    for _model, _optim, inp in (
-                        (ref_model, ref_optim, global_inp),
-                        (model, optim, local_inp),
-                    ):
+                    for _model, inp in ((ref_model, global_inp), (model, local_inp)):
                         losses.append(_model(inp).sum())
                         losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
@@ -778,7 +786,13 @@ def reduce_scatter_with_count(*args, **kwargs):
                 # Expect additional reduce-scatters for all MLPs
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
-            reduce_scatter_count = 0
+            # Exclude the loss all-reduce per microbatch in our training loop
+            all_reduce_count -= num_microbatches
+            if mesh.ndim == 2:
+                self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
+            else:
+                self.assertEqual(all_reduce_count, 0)
+            reduce_scatter_count = all_reduce_count = 0

             # Expect one all-gather per MLP plus one for the root's linear in
             # the first microbatch's forward
@@ -872,8 +886,7 @@ def _test_1f1b_microbatching(
             ref_losses.append(ref_model(inp).sum())
             ref_losses[-1].backward()
         for param in ref_model.parameters():
-            dist.all_reduce(param.grad)
-            param.grad.detach().div_(self.world_size)
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

         for loss, ref_loss in zip(losses, ref_losses):
             self.assertEqual(loss, ref_loss)
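For context on the new "mesh" subtest dimension: the test now always runs a 1D mesh (plain FSDP) and, when 4 GPUs are available, a 2D (replicate_size, shard_size) = (2, 2) mesh (HSDP). The snippet below is a hedged, minimal sketch of that mesh setup on a toy model, assuming a 4-rank torchrun launch and the FSDP2 prototype API touched by this PR; it is illustrative only, not part of the test.

# Minimal sketch, assuming 4 CUDA ranks launched with torchrun and the FSDP2
# prototype API (torch.distributed._composable.fsdp) as of this commit.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

fsdp_mesh = init_device_mesh("cuda", (4,))    # 1D mesh: plain FSDP
hsdp_mesh = init_device_mesh("cuda", (2, 2))  # 2D mesh: (replicate, shard), as in the test

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for layer in model:
    fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=True)
fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=True)

# Under the 2D mesh, each gradient reduce-scatter over the shard dim is
# followed by an all-reduce over the replicate dim, which is what the new
# assertions check: all_reduce_count == expected_reduce_scatter_count when
# mesh.ndim == 2, and 0 for the 1D mesh.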

torch/distributed/_composable/fsdp/_fsdp_param_group.py

Lines changed: 10 additions & 11 deletions
@@ -167,7 +167,7 @@ def _init_mp_dtypes(self) -> None:
     def _init_grad_divide_factors(self):
         data_parallel_world_size = 1
         data_parallel_world_size *= self.mesh_info.shard_mesh_size
-        if isinstance(self.mesh_info, HSDPMeshInfo):
+        if self._is_hsdp:
             data_parallel_world_size *= self.mesh_info.replicate_mesh_size
         if self._reduce_dtype in (torch.float32, torch.bfloat16):
             # Use NCCL's AVG op to divide after reduction since it is more
@@ -348,7 +348,7 @@ def post_backward(self, *unused: Any):
             self.device,
             self._grad_divide_factors,
             self._all_reduce_process_group
-            if self._should_all_reduce_grads()
+            if self._is_hsdp and self.all_reduce_grads
             else None,
             self.comm_ctx.all_reduce_stream,
         )
@@ -481,6 +481,10 @@ def _use_post_forward_mesh(self) -> bool:
             and self.mesh_info != self.post_forward_mesh_info
         )

+    @property
+    def _is_hsdp(self) -> bool:
+        return isinstance(self.mesh_info, HSDPMeshInfo)
+
     @property
     def _all_gather_process_group(self) -> dist.ProcessGroup:
         mesh_info = (
@@ -493,18 +497,13 @@ def _all_gather_process_group(self) -> dist.ProcessGroup:

     @property
     def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
-        mesh_info = self.mesh_info
-        assert isinstance(mesh_info, FSDPMeshInfo)
-        return mesh_info.shard_process_group
+        assert isinstance(self.mesh_info, FSDPMeshInfo)
+        return self.mesh_info.shard_process_group

     @property
     def _all_reduce_process_group(self) -> dist.ProcessGroup:
-        mesh_info = self.mesh_info
-        assert isinstance(mesh_info, HSDPMeshInfo)
-        return mesh_info.replicate_process_group
-
-    def _should_all_reduce_grads(self) -> bool:
-        return isinstance(self.mesh_info, HSDPMeshInfo) and self.all_reduce_grads
+        assert isinstance(self.mesh_info, HSDPMeshInfo)
+        return self.mesh_info.replicate_process_group


 def _get_param_module_infos(
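The new `_is_hsdp` property folds the repeated `isinstance(self.mesh_info, HSDPMeshInfo)` checks (and the removed `_should_all_reduce_grads` helper) into one place. The sketch below uses hypothetical stand-in classes, not the real FSDP2 types, to illustrate the same pattern.

# Illustrative sketch only; FSDPMeshInfoSketch / ParamGroupSketch are made-up
# stand-ins for the real FSDPMeshInfo / HSDPMeshInfo / FSDPParamGroup classes.
class FSDPMeshInfoSketch:
    shard_mesh_size = 2


class HSDPMeshInfoSketch(FSDPMeshInfoSketch):
    replicate_mesh_size = 2


class ParamGroupSketch:
    def __init__(self, mesh_info, all_reduce_grads: bool = True):
        self.mesh_info = mesh_info
        self.all_reduce_grads = all_reduce_grads

    @property
    def _is_hsdp(self) -> bool:
        # One place to ask "is this HSDP?" instead of repeating isinstance checks.
        return isinstance(self.mesh_info, HSDPMeshInfoSketch)

    def all_reduce_group_or_none(self, replicate_group):
        # Mirrors the post_backward change: hand over the replicate-dim process
        # group only when running HSDP with gradient all-reduce enabled.
        return replicate_group if self._is_hsdp and self.all_reduce_grads else None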

torch/testing/_internal/common_fsdp.py

Lines changed: 12 additions & 0 deletions
@@ -907,6 +907,18 @@ def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
         dist.reduce_scatter_tensor = orig_reduce_scatter


+@contextlib.contextmanager
+def patch_all_reduce(new_all_reduce: Callable):
+    orig_all_reduce = dist.all_reduce
+    dist.barrier()
+    dist.all_reduce = new_all_reduce
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_reduce = orig_all_reduce
+
+
 @no_type_check
 @contextlib.contextmanager
 def patch_unshard(new_unshard: Callable):
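The new `patch_all_reduce` helper mirrors the existing `patch_all_gather` / `patch_reduce_scatter`: it swaps `dist.all_reduce` for a wrapper inside the context, with barriers around the swap so ranks do not race past the patch. A hedged usage sketch, assuming the process group is already initialized (e.g., inside an `FSDPTest` case), following the same counting pattern as the test above:

# Usage sketch; assumes torch.distributed is already initialized.
import torch.distributed as dist
from torch.testing._internal.common_fsdp import patch_all_reduce

all_reduce_count = 0
orig_all_reduce = dist.all_reduce  # capture the real collective before patching

def all_reduce_with_count(*args, **kwargs):
    global all_reduce_count
    all_reduce_count += 1
    return orig_all_reduce(*args, **kwargs)

with patch_all_reduce(all_reduce_with_count):
    ...  # run forward/backward; every dist.all_reduce call is counted

print(f"observed {all_reduce_count} all-reduce calls")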
