     DynamoDistributedMultiProcTestCase,
     _dynamo_dist_per_rank_init,
     requires_nccl,
+    run_with_legacy_funcol,
+    run_with_both_funcol_impls,
+    run_with_both_funcol_impls_with_arg,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import requires_cuda
+from torch.testing._internal.common_utils import instantiate_parametrized_tests, requires_cuda
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch.utils._triton import has_triton
 from torch._inductor.utils import run_and_get_triton_code
@@ -53,6 +56,7 @@ def world_size(self) -> int:
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_broadcast_inductor(self):
         """
         Testing if broadcast works correctly when using inductor
@@ -89,6 +93,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allreduce_inductor(self):
         """
         This matmul/cat/allreduce is a pattern we aim to optimize.
@@ -131,6 +136,7 @@ def test_c10d_functional_tagged_pt2_compliant(self):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_eager_allreduce_inductor_wait(self):

         def eager_func(a, b, c, d, *, tag, ranks, group_size):
@@ -170,6 +176,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_inductor_allreduce_eager_wait(self):

         def inductor_func(a, b, c, d, *, tag, ranks, group_size):
@@ -208,6 +215,7 @@ def compile(func, example_inputs):
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allreduce_input_buffer_reuse(self):
         def func(a, *, tag, ranks, group_size):
             ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
@@ -227,6 +235,7 @@ def func(a, *, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_permute_tensor(self):
         def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
             return _functional_collectives.permute_tensor(tensor, src_dst_pairs, ranks, tag)
@@ -256,6 +265,7 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_output_buffer_reuse(self):
         class Model(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -281,6 +291,7 @@ def forward(self, x, world_size, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_contiguous_input(self):
         class Model(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -307,6 +318,7 @@ def forward(self, x, world_size, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_into_tensor_inductor(self):
         """
         This matmul/cat/allreduce is a pattern we aim to optimize.
@@ -339,6 +351,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_reduce_scatter_tensor_inductor(self):
         def example(a, b, *, tag, ranks, group_size):
             c = torch.matmul(a, b)
@@ -369,6 +382,7 @@ def compile(func, example_inputs):
     @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor(self):
         def example(inp, input_split_sizes_tensor, output_split_sizes_tensor, *, tag, ranks, group_size):
             input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
@@ -454,6 +468,7 @@ def example(inp, input_split_sizes_tensor, *, tag, ranks, group_size):
     @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor_input_split_sizes_none(self):
         def example(inp, output_split_sizes_tensor, *, tag, ranks, group_size):
             output_split_sizes = _tolist_with_constrain_as_size(output_split_sizes_tensor)
@@ -495,6 +510,7 @@ def example(inp, output_split_sizes_tensor, *, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor_split_sizes_none(self):
         def example(inp, *, tag, ranks, group_size):
             a2a = torch.ops.c10d_functional.all_to_all_single(
@@ -524,6 +540,7 @@ def example(inp, *, tag, ranks, group_size):
             self.assertTrue(same(eager_out, inductor_out, tol=0.001))


+@instantiate_parametrized_tests
 @requires_nccl()
 @requires_cuda
 class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@@ -539,6 +556,7 @@ def get_world_trs(self, world_size=1):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch(debug=True)
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_single_op(self):

         def func(inp, *, tag, ranks, group_size):
@@ -567,6 +585,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch(debug=True)
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_steal_buffer(self):
         """
         it's ok and optimal if inductor allreduce mutates the buffer of an intermediate
@@ -604,6 +623,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_doesnt_mutate_shared(self):
         """
         make sure that an intermediate that's going to be reused isn't mutated unless copied
@@ -641,40 +661,49 @@ def func(inp, *, tag, ranks, group_size):
         correct = func(inputs, **self.get_world_trs())
         self.assertTrue(same(out, correct))

-    def test_dynamo_trace_allreduce(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_allreduce(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_reduce(inp, "sum", ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_reduce(inp, "sum", "0")
+            else:
+                ar = _functional_collectives.all_reduce(inp, "sum", [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

-    def test_dynamo_trace_all_gather_tensor(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_all_gather_tensor(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_gather_tensor(inp, 0, ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
+            else:
+                ar = _functional_collectives.all_gather_tensor(inp, 0, [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (all_gather, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_trace_all_gather_tensor_pg(self):

         def func(inp, *, pg):
@@ -692,6 +721,7 @@ def func(inp, *, pg):
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather(self):

         def func(inp, out, *, pg):
@@ -717,6 +747,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather_list(self):

         def func(inp, out, *, pg):
@@ -739,6 +770,7 @@ def func(inp, out, *, pg):
         assert counter.frame_count == 1
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather_args_match(self):
         # Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
         # except uses kwargs to ensure rewrite has matching arg names
@@ -766,6 +798,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_reduce_scatter(self):

         def func(inp, out, *, pg):
@@ -791,6 +824,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_allreduce(self):

         def func(tensor, pg):
@@ -813,6 +847,7 @@ def func(tensor, pg):
         assert counter.op_count == 3
         assert same(inputs_compiled, inputs_eager)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_to_all_single(self):

         def func(output, input, pg):
@@ -836,6 +871,7 @@ def func(output, input, pg):
         assert counter.frame_count == 1
         assert same(output_compiled, output_eager)

+    @run_with_both_funcol_impls
     def test_dynamo_support_collective_op_with_async_op_False(self):

         def func(inp, out, *, pg):
@@ -862,6 +898,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_graphbreaks_unsupported_async_op(self):

         def func(inp, out, *, pg):
@@ -887,6 +924,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 0
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_pg_var(self):
         def func(inp, *, pg):
             x = pg.rank() + 1 % pg.size()
@@ -903,23 +941,28 @@ def func(inp, *, pg):
         assert counter.op_count == 1
         assert same(outputs, correct_outputs)

-    def test_dynamo_trace_reduce_scatter_tensor(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_reduce_scatter_tensor(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
+            else:
+                ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_trace_allgather_coalesced(self):
         def func(inp, *, tag, ranks, group_size):
             ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(inp, tag, ranks, group_size)
@@ -935,25 +978,29 @@ def func(inp, *, tag, ranks, group_size):
         assert same(out, correct)


-    def test_backwards(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_backwards(self, use_native_funcol):
         """
         It's probably not that common to need backwards support for collectives.

         However, I wanted to at least see if it was possible to support it as a design goal.
         """
-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_reduce(inp, "sum", ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_reduce(inp, "sum", "0")
+            else:
+                ar = _functional_collectives.all_reduce(inp, "sum", [0], "")
             return ar

         input = torch.ones(4, 4, device="cuda", requires_grad=True)
         # TODO implement backwards
         with self.assertRaisesRegex(RuntimeError, "element 0 of tensors does not require grad and does not have a grad_fn"):
             compiled = torch.compile(func, backend="aot_eager")  # inductor bug with single-op allreduce graph
-            out = compiled(input, **self.get_world_trs())
+            out = compiled(input)
             out.sum().backward()

             correct_input = input.clone().detach().requires_grad_()
-            correct = func(correct_input, **self.get_world_trs())
+            correct = func(correct_input)
             correct.sum().backward()
             self.assertTrue(same(out, correct))
             self.assertTrue(same(input.grad, correct_input.grad))
@@ -965,6 +1012,7 @@ def test_meta(self):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_all_gather_coalesced(self):
         """
         make sure that an intermediate that's going to be reused isn't mutated unless copied
@@ -1011,6 +1059,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_reduce_scatter_coalesced(self):
         """
         make sure that an intermediate that's going to be reused isn't mutated unless copied
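Note on the pattern used in this diff: run_with_legacy_funcol, run_with_both_funcol_impls, and run_with_both_funcol_impls_with_arg are imported from the same torch.testing._internal distributed test utilities as the other helpers at the top of the file, and their definitions are not part of this diff. Together with @instantiate_parametrized_tests on the test class, the *_with_arg variant runs each decorated test once per functional-collectives implementation and passes a use_native_funcol flag into the test body. The sketch below is only a hedged, hypothetical illustration of how such a decorator could be built on the parametrize utility this diff already relies on; the _native_funcol_enabled context manager and the USE_NATIVE_FUNCOL toggle are assumptions for illustration, not the actual PyTorch helpers.

# Hypothetical sketch only -- NOT the actual PyTorch helper. It shows one way a
# "run with both implementations" decorator can be expressed on top of the
# parametrize / @instantiate_parametrized_tests machinery used in this file.
import functools
import os
from contextlib import contextmanager

from torch.testing._internal.common_utils import parametrize


@contextmanager
def _native_funcol_enabled(enabled):
    # Assumed mechanism: an environment-variable toggle read by a hypothetical
    # dispatch layer. The real switch used by the PyTorch helpers may differ.
    prev = os.environ.get("USE_NATIVE_FUNCOL")
    os.environ["USE_NATIVE_FUNCOL"] = "1" if enabled else "0"
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop("USE_NATIVE_FUNCOL", None)
        else:
            os.environ["USE_NATIVE_FUNCOL"] = prev


def run_with_both_impls_with_arg(test_fn):
    # Parametrize the test over use_native_funcol in {True, False}; the class
    # must be decorated with @instantiate_parametrized_tests so the two
    # variants are generated as separate test cases.
    @parametrize("use_native_funcol", [True, False])
    @functools.wraps(test_fn)
    def wrapper(self, use_native_funcol, **kwargs):
        with _native_funcol_enabled(use_native_funcol):
            return test_fn(self, use_native_funcol, **kwargs)

    return wrapper

A legacy-only variant would simply pin the toggle to False around the test instead of parametrizing it, which matches how the diff tags implementation-specific Inductor tests with @run_with_legacy_funcol.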