     FSDPTestMultiThread,
     MLP,
     patch_all_gather,
+    patch_all_reduce,
     patch_reduce_scatter,
     test_compiled_fsdp,
 )
@@ -649,16 +650,21 @@ def _test_train_shared_params(
 class TestFullyShardGradientAccumulation(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(4, torch.cuda.device_count())

     @skip_if_lt_x_gpu(2)
     def test_gradient_accumulation(self):
         """
         Tests gradient accumulation with/without gradient reduction and
         with/without resharding after backward.
         """
+        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
+        if self.world_size == 4:  # test HSDP too if enough GPUs
+            shard_size, replicate_size = 2, 2
+            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
         self.run_subtests(
             {
+                "mesh": meshes,
                 "reshard_after_forward": [True, False, 2],
                 # "all": disable reduce-scatter for all modules
                 # "root_only": disable reduce-scatter for root's linear only
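
The new "mesh" subtest axis drives both sharding layouts through the same test body. As a rough standalone sketch (not part of this diff, assuming the init_device_mesh and fully_shard APIs imported by this test file and a 4-GPU launch via torchrun), a 1D mesh gives plain FSDP while a 2D mesh gives HSDP, with the first mesh dimension used for replication and the second for sharding:

# Standalone sketch only: 1D mesh -> FSDP, 2D mesh -> HSDP.
# Run under e.g. `torchrun --nproc_per_node=4 ...`; the fully_shard import path assumes
# the FSDP2 prototype location at the time of this change.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

fsdp_mesh = init_device_mesh("cuda", (4,))    # shard parameters across all 4 ranks
hsdp_mesh = init_device_mesh("cuda", (2, 2))  # replicate across dim 0, shard across dim 1

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fully_shard(model, mesh=hsdp_mesh)  # HSDP: reduce-scatter grads within the shard group,
                                    # then all-reduce them across the replicate group
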
@@ -672,6 +678,7 @@ def test_gradient_accumulation(self):

     def _test_gradient_accumulation(
         self,
+        mesh: DeviceMesh,
         reshard_after_forward: Union[bool, int],
         mode: str,
         reshard_after_backward: bool,
@@ -691,15 +698,13 @@ def _test_gradient_accumulation(
         global_batch_size = local_batch_size * self.world_size
         if mode == "some_mlps":
             num_mlps_to_disable_reduce_scatter = 2
-        model = nn.Sequential(
-            *(
-                [nn.Linear(lin_dim, lin_dim)]
-                + [MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
-            )
-        )
+        modules = [nn.Linear(lin_dim, lin_dim)]
+        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
+        model = nn.Sequential(*modules)
         ref_model = copy.deepcopy(model).cuda()
         fully_shard_fn = functools.partial(
             fully_shard,
+            mesh=mesh,
             reshard_after_forward=reshard_after_forward,
             offload_policy=offload_policy,
         )
@@ -709,10 +714,11 @@ def _test_gradient_accumulation(
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

+        # TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
         orig_all_gather = dist.all_gather_into_tensor
-        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
-        reduce_scatter_count = 0
+        orig_all_reduce = dist.all_reduce
+        all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0

         def all_gather_with_count(*args, **kwargs):
             nonlocal all_gather_count
@@ -724,11 +730,16 @@ def reduce_scatter_with_count(*args, **kwargs):
             reduce_scatter_count += 1
             return orig_reduce_scatter(*args, **kwargs)

+        def all_reduce_with_count(*args, **kwargs):
+            nonlocal all_reduce_count
+            all_reduce_count += 1
+            return orig_all_reduce(*args, **kwargs)
+
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
             with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
                 reduce_scatter_with_count
-            ):
+            ), patch_all_reduce(all_reduce_with_count):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
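
The patch_all_reduce helper comes from the same testing module as patch_all_gather and patch_reduce_scatter; its implementation is not shown in this diff. A hypothetical minimal version, assuming it only needs to swap out torch.distributed.all_reduce for the duration of the with block, might look like:

# Hypothetical sketch only; the real helper lives in torch.testing._internal.common_fsdp
# and may patch additional call sites.
import contextlib
import torch.distributed as dist

@contextlib.contextmanager
def patch_all_reduce(new_all_reduce):
    orig_all_reduce = dist.all_reduce
    dist.all_reduce = new_all_reduce  # the counting wrapper sees each c10d all-reduce call
    try:
        yield
    finally:
        dist.all_reduce = orig_all_reduce  # restore even if the body raises
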
@@ -756,10 +767,7 @@ def reduce_scatter_with_count(*args, **kwargs):
                         * local_batch_size
                     ].detach()
                     losses: List[torch.Tensor] = []
-                    for _model, _optim, inp in (
-                        (ref_model, ref_optim, global_inp),
-                        (model, optim, local_inp),
-                    ):
+                    for _model, inp in ((ref_model, global_inp), (model, local_inp)):
                         losses.append(_model(inp).sum())
                         losses[-1].backward()
                     dist.all_reduce(losses[1])  # partial -> replicated
@@ -778,7 +786,13 @@ def reduce_scatter_with_count(*args, **kwargs):
                 # Expect additional reduce-scatters for all MLPs
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
-            reduce_scatter_count = 0
+            # Exclude the loss all-reduce per microbatch in our training loop
+            all_reduce_count -= num_microbatches
+            if mesh.ndim == 2:
+                self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
+            else:
+                self.assertEqual(all_reduce_count, 0)
+            reduce_scatter_count = all_reduce_count = 0

             # Expect one all-gather per MLP plus one for the root's linear in
             # the first microbatch's forward
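
The accounting above relies on HSDP issuing one all-reduce (across the replicate mesh dimension) for every gradient reduce-scatter (within the shard mesh dimension), while plain 1D FSDP issues none. A small sketch of that expectation, mirroring the assertions in the hunk (not library code):

# Sketch of the counting rule the assertions above encode.
def expected_all_reduce_count(mesh_ndim: int, expected_reduce_scatter_count: int) -> int:
    if mesh_ndim == 2:
        # HSDP: each gradient bucket is reduce-scattered in the shard group and then
        # all-reduced across the replicate group, so the counts match one-to-one.
        return expected_reduce_scatter_count
    return 0  # 1D mesh (plain FSDP): gradients only need a reduce-scatter
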
@@ -872,8 +886,7 @@ def _test_1f1b_microbatching(
             ref_losses.append(ref_model(inp).sum())
             ref_losses[-1].backward()
         for param in ref_model.parameters():
-            dist.all_reduce(param.grad)
-            param.grad.detach().div_(self.world_size)
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

         for loss, ref_loss in zip(losses, ref_losses):
             self.assertEqual(loss, ref_loss)
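
The last hunk folds the SUM-then-divide pattern into a single averaging collective. A hypothetical helper showing the two equivalent forms (ReduceOp.AVG requires a backend that supports it, such as NCCL):

# Hypothetical helper, not from the diff: both branches average gradients across ranks.
import torch.distributed as dist

def average_gradients(model, use_avg_op: bool = True) -> None:
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is None:
            continue
        if use_avg_op:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)  # single collective
        else:
            dist.all_reduce(param.grad)  # defaults to ReduceOp.SUM
            param.grad.div_(world_size)  # then divide to recover the average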