     FSDPTestMultiThread,
     MLP,
     patch_all_gather,
+    patch_all_reduce,
     patch_reduce_scatter,
     test_compiled_fsdp,
 )
@@ -650,16 +651,21 @@ def _test_train_shared_params(
 class TestFullyShardGradientAccumulation(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(4, torch.cuda.device_count())
 
     @skip_if_lt_x_gpu(2)
     def test_gradient_accumulation(self):
         """
         Tests gradient accumulation with/without gradient reduction and
         with/without resharding after backward.
         """
+        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
+        if self.world_size == 4:  # test HSDP too if enough GPUs
+            shard_size, replicate_size = 2, 2
+            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
         self.run_subtests(
             {
+                "mesh": meshes,
                 "reshard_after_forward": [True, False, 2],
                 # "all": disable reduce-scatter for all modules
                 # "root_only": disable reduce-scatter for root's linear only
@@ -673,6 +679,7 @@ def test_gradient_accumulation(self):
 
     def _test_gradient_accumulation(
         self,
+        mesh: DeviceMesh,
         reshard_after_forward: Union[bool, int],
         mode: str,
         reshard_after_backward: bool,
@@ -692,15 +699,13 @@ def _test_gradient_accumulation(
         global_batch_size = local_batch_size * self.world_size
         if mode == "some_mlps":
             num_mlps_to_disable_reduce_scatter = 2
-        model = nn.Sequential(
-            *(
-                [nn.Linear(lin_dim, lin_dim)]
-                + [MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
-            )
-        )
+        modules = [nn.Linear(lin_dim, lin_dim)]
+        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
+        model = nn.Sequential(*modules)
         ref_model = copy.deepcopy(model).cuda()
         fully_shard_fn = functools.partial(
             fully_shard,
+            mesh=mesh,
             reshard_after_forward=reshard_after_forward,
             offload_policy=offload_policy,
         )
@@ -710,10 +715,11 @@ def _test_gradient_accumulation(
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
 
+        # TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
         orig_all_gather = dist.all_gather_into_tensor
-        all_gather_count = 0
         orig_reduce_scatter = dist.reduce_scatter_tensor
-        reduce_scatter_count = 0
+        orig_all_reduce = dist.all_reduce
+        all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0
 
         def all_gather_with_count(*args, **kwargs):
             nonlocal all_gather_count
@@ -725,11 +731,16 @@ def reduce_scatter_with_count(*args, **kwargs):
             reduce_scatter_count += 1
             return orig_reduce_scatter(*args, **kwargs)
 
+        def all_reduce_with_count(*args, **kwargs):
+            nonlocal all_reduce_count
+            all_reduce_count += 1
+            return orig_all_reduce(*args, **kwargs)
+
         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(5):
             with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
                 reduce_scatter_with_count
-            ):
+            ), patch_all_reduce(all_reduce_with_count):
                 for microbatch_idx in range(num_microbatches):
                     is_last_microbatch = microbatch_idx == num_microbatches - 1
                     if mode == "all":
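
The hunk above patches the collectives with counting wrappers. A self-contained sketch of the same idea, using `unittest.mock.patch.object` as a stand-in for the test suite's `patch_*` helpers; `count_all_reduces` is a hypothetical name, not part of the test utilities.

# Illustrative stand-in for the patching used above: temporarily swap
# dist.all_reduce for a counting wrapper, then restore the original.
import unittest.mock

import torch.distributed as dist


def count_all_reduces(fn) -> int:
    orig_all_reduce = dist.all_reduce
    count = 0

    def all_reduce_with_count(*args, **kwargs):
        nonlocal count
        count += 1
        return orig_all_reduce(*args, **kwargs)

    with unittest.mock.patch.object(dist, "all_reduce", new=all_reduce_with_count):
        fn()  # run the training step(s) under the counting patch
    return count
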
@@ -757,10 +768,7 @@ def reduce_scatter_with_count(*args, **kwargs):
                     * local_batch_size
                 ].detach()
                 losses: List[torch.Tensor] = []
-                for _model, _optim, inp in (
-                    (ref_model, ref_optim, global_inp),
-                    (model, optim, local_inp),
-                ):
+                for _model, inp in ((ref_model, global_inp), (model, local_inp)):
                     losses.append(_model(inp).sum())
                     losses[-1].backward()
                 dist.all_reduce(losses[1])  # partial -> replicated
@@ -779,7 +787,13 @@ def reduce_scatter_with_count(*args, **kwargs):
                 # Expect additional reduce-scatters for all MLPs
                 expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
             self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
-            reduce_scatter_count = 0
+            # Exclude the loss all-reduce per microbatch in our training loop
+            all_reduce_count -= num_microbatches
+            if mesh.ndim == 2:
+                self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
+            else:
+                self.assertEqual(all_reduce_count, 0)
+            reduce_scatter_count = all_reduce_count = 0
 
             # Expect one all-gather per MLP plus one for the root's linear in
             # the first microbatch's forward
@@ -873,8 +887,7 @@ def _test_1f1b_microbatching(
             ref_losses.append(ref_model(inp).sum())
             ref_losses[-1].backward()
         for param in ref_model.parameters():
-            dist.all_reduce(param.grad)
-            param.grad.detach().div_(self.world_size)
+            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
 
         for loss, ref_loss in zip(losses, ref_losses):
             self.assertEqual(loss, ref_loss)
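
For context on the `mesh.ndim == 2` branch added earlier: with HSDP, each gradient reduction is a reduce-scatter over the shard mesh dimension followed by an all-reduce over the replicate dimension, so the test expects exactly one all-reduce per reduce-scatter on a 2D mesh and none (beyond the loss all-reduce in the loop) on a 1D mesh. A minimal sketch of that pairing; `hsdp_reduce_grad` is a hypothetical helper, not FSDP's internal implementation, and it assumes a flattened gradient whose length is divisible by the shard group size.

# Rough sketch of the collective pairing the 2D-mesh assertion checks.
import torch
import torch.distributed as dist


def hsdp_reduce_grad(
    unsharded_grad: torch.Tensor, shard_group, replicate_group
) -> torch.Tensor:
    shard_world_size = dist.get_world_size(shard_group)
    # Reduce-scatter within the shard dimension to get this rank's gradient shard
    sharded_grad = unsharded_grad.new_empty(unsharded_grad.numel() // shard_world_size)
    dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, group=shard_group)
    # One all-reduce per reduce-scatter, across the replicate dimension
    dist.all_reduce(sharded_grad, group=replicate_group)
    return sharded_grad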