Commit dd6b5e2

Yifu Wang authored and pytorchmergebot committed
Prepare test_inductor_collectives.py for native funcol migration (#120025)
Some tests in this file are implementation specific, e.g. they verify generated code via `FileCheck`. These tests are already covered for native funcol in test_c10d_functional_native.py, so they are marked with `@run_with_legacy_funcol`. The remaining tests are marked with `@run_with_both_funcol_impls`.

Pull Request resolved: #120025
Approved by: https://github.com/wanchaol
ghstack dependencies: #119982
1 parent af765db commit dd6b5e2

2 files changed: 77 additions & 21 deletions

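For context on the markers used below: `@run_with_legacy_funcol` pins a test to the legacy implementation, while `@run_with_both_funcol_impls` / `@run_with_both_funcol_impls_with_arg` run it under both implementations, the latter passing a `use_native_funcol` flag into the test body. Their definitions live in torch.testing._internal and are not part of this diff; the snippet below is only a rough, self-contained sketch of the "expand one test into per-implementation variants" idea, and every helper name in it is made up for illustration.

import functools
import unittest


def run_with_both_funcol_impls_with_arg(test_fn):
    # Sketch only: mark the test so the class decorator below expands it into
    # a native-funcol and a legacy-funcol variant, passing the flag as an argument.
    test_fn._funcol_variants = (True, False)
    return test_fn


def instantiate_funcol_variants(cls):
    # Sketch of the expansion step (the real tests use
    # @instantiate_parametrized_tests from torch.testing._internal.common_utils).
    for name, fn in list(vars(cls).items()):
        variants = getattr(fn, "_funcol_variants", None)
        if not variants:
            continue
        for use_native in variants:
            suffix = "native" if use_native else "legacy"

            def make_variant(fn=fn, use_native=use_native):
                @functools.wraps(fn)
                def wrapper(self):
                    return fn(self, use_native)
                return wrapper

            # Register e.g. test_collective_native / test_collective_legacy.
            setattr(cls, f"{name}_{suffix}", make_variant())
        delattr(cls, name)
    return cls


@instantiate_funcol_variants
class ExampleTest(unittest.TestCase):
    @run_with_both_funcol_impls_with_arg
    def test_collective(self, use_native_funcol):
        # In the real tests this flag selects the native vs. legacy call form.
        self.assertIn(use_native_funcol, (True, False))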
test/distributed/test_inductor_collectives.py

Lines changed: 70 additions & 21 deletions
@@ -17,9 +17,12 @@
     DynamoDistributedMultiProcTestCase,
     _dynamo_dist_per_rank_init,
     requires_nccl,
+    run_with_legacy_funcol,
+    run_with_both_funcol_impls,
+    run_with_both_funcol_impls_with_arg,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import requires_cuda
+from torch.testing._internal.common_utils import instantiate_parametrized_tests, requires_cuda
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch.utils._triton import has_triton
 from torch._inductor.utils import run_and_get_triton_code
@@ -53,6 +56,7 @@ def world_size(self) -> int:
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_broadcast_inductor(self):
         """
         Testing if broadcast works correctly when using inductor
@@ -89,6 +93,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allreduce_inductor(self):
         """
         This is matmul/cat/allreduce is a pattern we aim to optimize.
@@ -131,6 +136,7 @@ def test_c10d_functional_tagged_pt2_compliant(self):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_eager_allreduce_inductor_wait(self):

         def eager_func(a, b, c, d, *, tag, ranks, group_size):
@@ -170,6 +176,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_inductor_allreduce_eager_wait(self):

         def inductor_func(a, b, c, d, *, tag, ranks, group_size):
@@ -208,6 +215,7 @@ def compile(func, example_inputs):
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allreduce_input_buffer_reuse(self):
         def func(a, *, tag, ranks, group_size):
             ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
@@ -227,6 +235,7 @@ def func(a, *, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_permute_tensor(self):
         def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
             return _functional_collectives.permute_tensor(tensor, src_dst_pairs, ranks, tag)
@@ -256,6 +265,7 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_output_buffer_reuse(self):
         class Model(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -281,6 +291,7 @@ def forward(self, x, world_size, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_contiguous_input(self):
         class Model(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
@@ -307,6 +318,7 @@ def forward(self, x, world_size, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_allgather_into_tensor_inductor(self):
         """
         This is matmul/cat/allreduce is a pattern we aim to optimize.
@@ -339,6 +351,7 @@ def compile(func, example_inputs):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_reduce_scatter_tensor_inductor(self):
         def example(a, b, *, tag, ranks, group_size):
             c = torch.matmul(a, b)
@@ -369,6 +382,7 @@ def compile(func, example_inputs):
     @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor(self):
         def example(inp, input_split_sizes_tensor, output_split_sizes_tensor, *, tag, ranks, group_size):
             input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
@@ -454,6 +468,7 @@ def example(inp, input_split_sizes_tensor, *, tag, ranks, group_size):
     @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor_input_split_sizes_none(self):
         def example(inp, output_split_sizes_tensor, *, tag, ranks, group_size):
             output_split_sizes = _tolist_with_constrain_as_size(output_split_sizes_tensor)
@@ -495,6 +510,7 @@ def example(inp, output_split_sizes_tensor, *, tag, ranks, group_size):
     @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
+    @run_with_legacy_funcol
     def test_all_to_all_single_inductor_split_sizes_none(self):
         def example(inp, *, tag, ranks, group_size):
             a2a = torch.ops.c10d_functional.all_to_all_single(
@@ -524,6 +540,7 @@ def example(inp, *, tag, ranks, group_size):
         self.assertTrue(same(eager_out, inductor_out, tol=0.001))


+@instantiate_parametrized_tests
 @requires_nccl()
 @requires_cuda
 class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@@ -539,6 +556,7 @@ def get_world_trs(self, world_size=1):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch(debug=True)
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_single_op(self):

         def func(inp, *, tag, ranks, group_size):
@@ -567,6 +585,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch(debug=True)
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_steal_buffer(self):
         """
         it's ok and optimal if inductor allreduce mutates the buffer of an intermediate
@@ -604,6 +623,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_doesnt_mutate_shared(self):
         """
         make sure that an intermediate that's going to be reuse isn't mutated unless copied
@@ -641,40 +661,49 @@ def func(inp, *, tag, ranks, group_size):
         correct = func(inputs, **self.get_world_trs())
         self.assertTrue(same(out, correct))

-    def test_dynamo_trace_allreduce(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_allreduce(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_reduce(inp, "sum", ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_reduce(inp, "sum", "0")
+            else:
+                ar = _functional_collectives.all_reduce(inp, "sum", [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

-    def test_dynamo_trace_all_gather_tensor(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_all_gather_tensor(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_gather_tensor(inp, 0, ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
+            else:
+                ar = _functional_collectives.all_gather_tensor(inp, 0, [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (all_gather, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_trace_all_gather_tensor_pg(self):

         def func(inp, *, pg):
@@ -692,6 +721,7 @@ def func(inp, *, pg):
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather(self):

         def func(inp, out, *, pg):
@@ -717,6 +747,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather_list(self):

         def func(inp, out, *, pg):
@@ -739,6 +770,7 @@ def func(inp, out, *, pg):
         assert counter.frame_count == 1
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_gather_args_match(self):
         # Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
         # except uses kwargs to ensure rewrite has matching arg names
@@ -766,6 +798,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_reduce_scatter(self):

         def func(inp, out, *, pg):
@@ -791,6 +824,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_allreduce(self):

         def func(tensor, pg):
@@ -813,6 +847,7 @@ def func(tensor, pg):
         assert counter.op_count == 3
         assert same(inputs_compiled, inputs_eager)

+    @run_with_both_funcol_impls
     def test_dynamo_rewrite_dist_all_to_all_single(self):

         def func(output, input, pg):
@@ -836,6 +871,7 @@ def func(output, input, pg):
         assert counter.frame_count == 1
         assert same(output_compiled, output_eager)

+    @run_with_both_funcol_impls
     def test_dynamo_support_collective_op_with_async_op_False(self):

         def func(inp, out, *, pg):
@@ -862,6 +898,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 3
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_graphbreaks_unsupported_async_op(self):

         def func(inp, out, *, pg):
@@ -887,6 +924,7 @@ def func(inp, out, *, pg):
         assert counter.op_count == 0
         assert same(outputs, correct_outputs)

+    @run_with_both_funcol_impls
     def test_dynamo_pg_var(self):
         def func(inp, *, pg):
             x = pg.rank() + 1 % pg.size()
@@ -903,23 +941,28 @@ def func(inp, *, pg):
         assert counter.op_count == 1
         assert same(outputs, correct_outputs)

-    def test_dynamo_trace_reduce_scatter_tensor(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_dynamo_trace_reduce_scatter_tensor(self, use_native_funcol):

-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
+            else:
+                ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, [0], "")
             return ar

         inputs = torch.ones(4, 4, device="cuda")
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter)
-        out = compiled(inputs, **self.get_world_trs())
-        correct = func(inputs, **self.get_world_trs())
+        out = compiled(inputs)
+        correct = func(inputs)
         self.assertEqual(counter.frame_count, 1)

         # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))

+    @run_with_both_funcol_impls
     def test_dynamo_trace_allgather_coalesced(self):
         def func(inp, *, tag, ranks, group_size):
             ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(inp, tag, ranks, group_size)
@@ -935,25 +978,29 @@ def func(inp, *, tag, ranks, group_size):
         assert same(out, correct)


-    def test_backwards(self):
+    @run_with_both_funcol_impls_with_arg
+    def test_backwards(self, use_native_funcol):
         """
         It's probably not that common to need backwards support for collectives.

         However, I wanted to at least see if it was possible to support it as a design goal.
         """
-        def func(inp, *, tag, ranks, group_size):
-            ar = _functional_collectives.all_reduce(inp, "sum", ranks, tag)
+        def func(inp):
+            if use_native_funcol:
+                ar = _functional_collectives.all_reduce(inp, "sum", "0")
+            else:
+                ar = _functional_collectives.all_reduce(inp, "sum", [0], "")
             return ar

         input = torch.ones(4, 4, device="cuda", requires_grad=True)
         # TODO implement backwards
         with self.assertRaisesRegex(RuntimeError, "element 0 of tensors does not require grad and does not have a grad_fn"):
             compiled = torch.compile(func, backend="aot_eager")  # inductor bug with single-op allreduce graph
-            out = compiled(input, **self.get_world_trs())
+            out = compiled(input)
             out.sum().backward()

             correct_input = input.clone().detach().requires_grad_()
-            correct = func(correct_input, **self.get_world_trs())
+            correct = func(correct_input)
             correct.sum().backward()
             self.assertTrue(same(out, correct))
             self.assertTrue(same(input.grad, correct_input.grad))
@@ -965,6 +1012,7 @@ def test_meta(self):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_all_gather_coalesced(self):
         """
         make sure that an intermediate that's going to be reuse isn't mutated unless copied
@@ -1011,6 +1059,7 @@ def func(inp, *, tag, ranks, group_size):

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
+    @run_with_legacy_funcol  # impl specific
     def test_inductor_reduce_scatter_coalesced(self):
         """
         make sure that an intermediate that's going to be reuse isn't mutated unless copied

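A note on the branch the parametrized tests above take: the legacy (py) funcol API identifies the group with an explicit rank list plus a tag string, while the native funcol API accepts a group name. The sketch below restates that branch, pulled straight from the diff; the literals "0" and [0], "" match the single-process test fixture used in this file, and running it for real requires an initialized process group.

import torch
import torch.distributed._functional_collectives as _functional_collectives


def allreduce_sum(inp: torch.Tensor, use_native_funcol: bool) -> torch.Tensor:
    # Same shape as the helpers added throughout the diff: only the way the
    # process group is addressed differs between the two implementations.
    if use_native_funcol:
        return _functional_collectives.all_reduce(inp, "sum", "0")  # group name
    return _functional_collectives.all_reduce(inp, "sum", [0], "")  # ranks + tag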
torch/distributed/_functional_collectives.py

Lines changed: 7 additions & 0 deletions
@@ -480,6 +480,13 @@ def all_to_all_single(
     if native_funcol_enabled():
         group_name = _resolve_group_name(group, tag)
         group_size = c10d._get_group_size_by_name(group_name)
+        if output_split_sizes is None or input_split_sizes is None:
+            assert output_split_sizes is None and input_split_sizes is None, (
+                "output_split_sizes and input_split_sizes must either be "
+                "specified together or both set to None"
+            )
+            output_split_sizes = [self.shape[0] // group_size] * group_size
+            input_split_sizes = output_split_sizes
         tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
             self,
             output_split_sizes,

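The _functional_collectives.py change lets the native funcol path of all_to_all_single accept omitted split sizes: when both lists are None, it now computes an even split of dim 0 across the group before calling the underlying op. A small illustration of that default, with made-up numbers (the even-split assumption mirrors the integer division in the new code):

# Illustration only: the default the new code computes when both split-size
# lists are None -- each of the group_size ranks handles shape[0] // group_size rows.
def default_split_sizes(dim0: int, group_size: int) -> list[int]:
    return [dim0 // group_size] * group_size


print(default_split_sizes(8, 4))   # [2, 2, 2, 2]
print(default_split_sizes(16, 2))  # [8, 8]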