12 | 12 | import torch.distributed._functional_collectives as _functional_collectives |
13 | 13 | from torch._C import FileCheck |
14 | 14 | from torch._dynamo.utils import counters, same |
15 | | -from torch._inductor.utils import run_and_get_triton_code |
| 15 | +from torch._inductor.utils import run_and_get_code, run_and_get_triton_code |
16 | 16 | from torch.testing._internal.common_distributed import ( |
17 | 17 | _dynamo_dist_per_rank_init, |
18 | 18 | at_least_x_gpu, |
@@ -67,6 +67,8 @@ def get_patches(): |
67 | 67 | "reorder_for_compute_comm_overlap_passes": [], |
68 | 68 | "compile_threads": 1, |
69 | 69 | "force_disable_caches": True, |
| 70 | + # Inserting overlap deps would break the existing FileCheck test strings |
| 71 | + "test_configs.aten_fx_overlap_insert_overlap_deps": False, |
70 | 72 | } |
71 | 73 |
72 | 74 |
@@ -358,6 +360,8 @@ def get_bucket_patches(compute_multiplier=1.0): |
358 | 360 | "reorder_for_compute_comm_overlap_passes": [], |
359 | 361 | "compile_threads": 1, |
360 | 362 | "force_disable_caches": True, |
| 363 | + # Inserting overlap deps would break the existing FileCheck test strings |
| 364 | + "test_configs.aten_fx_overlap_insert_overlap_deps": False, |
361 | 365 | } |
362 | 366 |
363 | 367 |
@@ -750,6 +754,85 @@ def func(a, b, c, *, ranks): |
750 | 754 | correct = func(a, b, c, ranks=ranks) |
751 | 755 | self.assertTrue(same(out, correct)) |
752 | 756 |
| 757 | + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") |
| 758 | + @torch._inductor.config.patch(get_bucket_patches(2.0)) |
| 759 | + def test_bucketing_split_for_overlap_blocking_deps_inductor(self): |
| 760 | + """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute.""" |
| 761 | + |
| 762 | + # Check that the overlap ordering is preserved in the Inductor-generated code |
| 763 | + |
| 764 | + def func(a, b, c, d, *, ranks): |
| 765 | + # All 4 all-gathers are independent - COULD be bucketed together |
| 766 | + ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) |
| 767 | + ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) |
| 768 | + ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks) |
| 769 | + ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks) |
| 770 | + |
| 771 | + # First compute - can hide ag1 and ag2 |
| 772 | + e = a * 5 # Use a to avoid fusion |
| 773 | + mm1 = torch.matmul(e, e.T) |
| 774 | + |
| 775 | + # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred) |
| 776 | + # Use first 8x8 elements to match mm1's shape |
| 777 | + intermediate = ag1[:8, :8] + ag2[:8, :8] |
| 778 | + |
| 779 | + # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4 |
| 780 | + mm2 = torch.matmul(mm1 + intermediate, c[:8]) |
| 781 | + |
| 782 | + # Use all results |
| 783 | + result = ( |
| 784 | + ag1.sum() * 1.1 |
| 785 | + + ag2.sum() * 1.2 |
| 786 | + + ag3.sum() * 1.3 |
| 787 | + + ag4.sum() * 1.4 |
| 788 | + + mm1.sum() |
| 789 | + + mm2.sum() |
| 790 | + ) |
| 791 | + return result |
| 792 | + |
| 793 | + li = [] |
| 794 | + apply = functools.partial(apply_reordering_and_get_graph, out_li=li) |
| 795 | + with ( |
| 796 | + _dynamo_dist_per_rank_init( |
| 797 | + self.rank, |
| 798 | + self.world_size, |
| 799 | + self.backend(device_type), |
| 800 | + fake_pg=not at_least_x_gpu(2), |
| 801 | + ), |
| 802 | + torch._inductor.config.patch( |
| 803 | + "test_configs.aten_fx_overlap_insert_overlap_deps", True |
| 804 | + ), |
| 805 | + torch._inductor.config.patch(post_grad_custom_post_pass=apply), |
| 806 | + ): |
| 807 | + a = torch.ones(8, 8, dtype=torch.float, device=device_type) |
| 808 | + b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2 |
| 809 | + c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3 |
| 810 | + d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4 |
| 811 | + ranks = list(range(self.world_size)) |
| 812 | + |
| 813 | + func_c = functools.partial(func, ranks=ranks) |
| 814 | + compiled = torch.compile(func_c) |
| 815 | + test_out, (code,) = run_and_get_code(compiled, a, b, c, d) |
| 816 | + |
| 817 | + # Check that the right control deps are added in the post-grad graph |
| 818 | + f = FileCheck() |
| 819 | + for _ in range(2): |
| 820 | + f.check("control_deps").check_same("all_gather").check_same( |
| 821 | + "subgraph_mm" |
| 822 | + ) |
| 823 | + f.check("control_deps").check_same("mm").check_same("subgraph_wait") |
| 824 | + f.run(li[0]) |
| 825 | + |
| 826 | + f = FileCheck() |
| 827 | + for _ in range(2): |
| 828 | + f.check_count("all_gather_into_tensor_out.default(", 1, exactly=True) |
| 829 | + f.check_count("extern_kernels.mm(", 1, exactly=True) |
| 830 | + f.check_count("wait_tensor.default(", 1, exactly=True) |
| 831 | + f.run(code) |
| 832 | + |
| 833 | + correct = func(a, b, c, d, ranks=ranks) |
| 834 | + self.assertTrue(same(test_out, correct)) |
| 835 | + |
753 | 836 |
754 | 837 | if __name__ == "__main__": |
755 | 838 | from torch._dynamo.test_case import run_tests |