@@ -738,51 +738,54 @@ def test_hsdp_broadcast_across_replicas(self):
         )
         model_args = ModelArgs()
         model = Transformer(model_args)
+        # Add a buffer to show that this flow works for buffers too
+        model.register_buffer("buf", torch.randn((model_args.dim,)))
         for module in model.modules():
             if isinstance(module, TransformerBlock):
                 fully_shard(module, mesh=mesh)
         fully_shard(model, mesh=mesh)

-        # Only preserve the model parameters on the replicate mesh's rank 0
+        # Only preserve the model states on the replicate mesh's rank 0
         if mesh.get_local_rank("replicate") > 0:
-            print(f"[Rank {self.rank}] filling with 1337")
-            for param in model.parameters():
-                param.detach().fill_(1337)
+            for tensor in itertools.chain(model.parameters(), model.buffers()):
+                tensor.detach().fill_(1337)

         # Check that replicas are different
-        for param_name, param in model.named_parameters():
-            local_param = param.to_local()
-            local_param_list = [
-                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
+            local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
+            local_tensor_list = [
+                torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
             ]
             dist.all_gather(
-                local_param_list, local_param, group=mesh.get_group("replicate")
+                local_tensor_list, local_tensor, group=mesh.get_group("replicate")
             )
-            for other_local_param in local_param_list[1:]:
-                self.assertEqual(other_local_param.shape, local_param_list[0].shape)
-                self.assertNotEqual(other_local_param, local_param_list[0])
+            for other_local_tensor in local_tensor_list[1:]:
+                self.assertEqual(other_local_tensor.shape, local_tensor_list[0].shape)
+                self.assertNotEqual(other_local_tensor, local_tensor_list[0])

         # Broadcast from replicate mesh's rank 0
         replicate_group = mesh.get_group("replicate")
-        for param in model.parameters():
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
             # E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
             # replicating on dim-0, broadcast with sources 0, 1, 2, 3
             src_rank = dist.get_process_group_ranks(replicate_group)[0]
             torch.distributed.broadcast(
-                param.to_local(), src=src_rank, group=replicate_group
+                tensor.to_local() if isinstance(tensor, DTensor) else tensor,
+                src=src_rank,
+                group=replicate_group,
             )

         # Check that replicas are the same
-        for param_name, param in model.named_parameters():
-            local_param = param.to_local()
-            local_param_list = [
-                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
+            local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
+            local_tensor_list = [
+                torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
             ]
             dist.all_gather(
-                local_param_list, local_param, group=mesh.get_group("replicate")
+                local_tensor_list, local_tensor, group=mesh.get_group("replicate")
             )
-            for other_local_param in local_param_list[1:]:
-                self.assertEqual(other_local_param, local_param_list[0])
+            for other_local_tensor in local_tensor_list[1:]:
+                self.assertEqual(other_local_tensor, local_tensor_list[0])

         # Check that we can run an iteration without erroring
         inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
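
For reference, the broadcast step in this hunk can be factored into a standalone helper. The sketch below is illustrative only and not part of the change: the function name is hypothetical, it assumes a 2-D HSDP DeviceMesh with a "replicate" dim, and the DTensor import path may differ across PyTorch versions (older releases expose it under torch.distributed._tensor).

# Hypothetical helper mirroring the broadcast step above (sketch, not the test's API).
import itertools

import torch.distributed as dist
from torch.distributed.tensor import DTensor


def broadcast_from_replicate_rank0(model, mesh):
    """Overwrite every replica's parameters and buffers with replicate-rank-0's copy.

    Sharded parameters are DTensors, so their local shards are broadcast;
    buffers remain plain tensors and are broadcast directly.
    """
    replicate_group = mesh.get_group("replicate")
    # The first rank in the replicate process group acts as the broadcast source.
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    for tensor in itertools.chain(model.parameters(), model.buffers()):
        local = tensor.to_local() if isinstance(tensor, DTensor) else tensor
        dist.broadcast(local, src=src_rank, group=replicate_group)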