@@ -725,5 +725,69 @@ def test_process_group_init(self):
         self.assertEqual(param.grad, ref_param.grad)
 
 
+class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_hsdp_broadcast_across_replicas(self):
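+        """
+        Tests that parameter values can be broadcast from one replica to the
+        others across the replicate dim of an HSDP mesh via ``to_local()``.
+        """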
+        shard_size, replicate_size = 2, 2
+        mesh = init_device_mesh(
+            "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
+        )
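+        # 2D mesh over 4 ranks: dim-0 "replicate" (size 2) x dim-1 "shard" (size 2)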
+        model_args = ModelArgs()
+        model = Transformer(model_args)
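+        # With a 2D mesh, fully_shard shards parameters along the last mesh dim
+        # ("shard") and replicates along the first ("replicate"), i.e. HSDP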
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module, mesh=mesh)
+        fully_shard(model, mesh=mesh)
+
+        # Only preserve the model parameters on the replicate mesh's rank 0
+        if mesh.get_local_rank("replicate") > 0:
+            print(f"[Rank {self.rank}] filling with 1337")
+            for param in model.parameters():
+                param.detach().fill_(1337)
+
+        # Check that replicas are different
+        for param_name, param in model.named_parameters():
+            local_param = param.to_local()
+            local_param_list = [
+                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+            ]
+            dist.all_gather(
+                local_param_list, local_param, group=mesh.get_group("replicate")
+            )
+            for other_local_param in local_param_list[1:]:
+                self.assertEqual(other_local_param.shape, local_param_list[0].shape)
+                self.assertNotEqual(other_local_param, local_param_list[0])
+
+        # Broadcast from the replicate mesh's rank 0
+        replicate_group = mesh.get_group("replicate")
+        for param in model.parameters():
+            # E.g. for mesh [[0, 1], [2, 3]] replicating on dim-0 and sharding
+            # on dim-1, the replicate groups are {0, 2} and {1, 3}, so ranks 0
+            # and 1 are the broadcast sources
+            src_rank = dist.get_process_group_ranks(replicate_group)[0]
+            torch.distributed.broadcast(
+                param.to_local(), src=src_rank, group=replicate_group
+            )
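+        # Note: to_local() returns a view of the DTensor's local shard, so the
+        # in-place broadcast above writes directly into the sharded parameters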
+
+        # Check that replicas are the same
+        for param_name, param in model.named_parameters():
+            local_param = param.to_local()
+            local_param_list = [
+                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+            ]
+            dist.all_gather(
+                local_param_list, local_param, group=mesh.get_group("replicate")
+            )
+            for other_local_param in local_param_list[1:]:
+                self.assertEqual(other_local_param, local_param_list[0])
+
+        # Check that we can run an iteration without erroring
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        model(inp).sum().backward()
+
+
 if __name__ == "__main__":
     run_tests()