
Commit 58700c5

respect aten planned overlap in inductor
ghstack-source-id: fa58952
Pull Request resolved: #164569

1 parent: 0d92d52

File tree: 4 files changed, +143 -1 lines changed


test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 90 additions & 1 deletion
@@ -12,7 +12,7 @@
 import torch.distributed._functional_collectives as _functional_collectives
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
 from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     at_least_x_gpu,

@@ -67,6 +67,8 @@ def get_patches():
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # Messes up existing test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }


@@ -358,6 +360,8 @@ def get_bucket_patches(compute_multiplier=1.0):
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # messes up test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }


@@ -750,6 +754,91 @@ def func(a, b, c, *, ranks):
             correct = func(a, b, c, ranks=ranks)
             self.assertTrue(same(out, correct))

+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches(2.0))
+    def test_bucketing_split_for_overlap_blocking_deps_inductor(self):
+        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
+
+        # check that ordering is preserved in inductor
+
+        def func(a, b, c, d, *, ranks):
+            # All 4 all-gathers are independent - COULD be bucketed together
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
+            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
+
+            # First compute - can hide ag1 and ag2
+            e = a * 5  # Use a to avoid fusion
+            mm1 = torch.matmul(e, e.T)
+
+            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
+            # Use first 8x8 elements to match mm1's shape
+            intermediate = ag1[:8, :8] + ag2[:8, :8]
+
+            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
+            mm2 = torch.matmul(mm1 + intermediate, c[:8])
+
+            # Use all results
+            result = (
+                ag1.sum() * 1.1
+                + ag2.sum() * 1.2
+                + ag3.sum() * 1.3
+                + ag4.sum() * 1.4
+                + mm1.sum()
+                + mm2.sum()
+            )
+            return result
+
+        li = []
+        apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
+        with (
+            _dynamo_dist_per_rank_init(
+                self.rank,
+                self.world_size,
+                self.backend(device_type),
+                fake_pg=not at_least_x_gpu(2),
+            ),
+            torch._inductor.config.patch(
+                "test_configs.aten_fx_overlap_insert_overlap_deps", True
+            ),
+            torch._inductor.config.patch(post_grad_custom_post_pass=apply),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
+            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            compiled = torch.compile(func_c)
+            test_out, (code,) = run_and_get_code(compiled, a, b, c, d)
+
+            # Check that right deps are added
+            f = FileCheck()
+            for _ in range(2):
+                f.check("control_deps_op").check_same("all_gather").check_same(
+                    "subgraph_mm"
+                )
+            f.check("control_deps_op").check_same("mm").check_same("subgraph_wait")
+            f.run(li[0])
+
+            f = FileCheck()
+            f.check("def call").check(
+                "torch.ops._c10d_functional.all_gather_into_tensor"
+            )
+            f.check_count(".mm(", 1, exactly=True)
+            f.check_count(".wait(", 1, exactly=True)
+            f.check_count(
+                "torch.ops._c10d_functional.all_gather_into_tensor_", 1, exactly=True
+            )
+            f.check_count(".mm(", 1, exactly=True)
+            f.check_count(".wait(", 1, exactly=True)
+            f.run(code)
+
+            correct = func(a, b, c, d, ranks=ranks)
+            self.assertTrue(same(test_out, correct))
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions
@@ -2025,6 +2025,9 @@ class test_configs:
     # to be migrated when ready for use
     aten_fx_overlap_scheduling = False

+    # insert ordering deps for overlap
+    aten_fx_overlap_insert_overlap_deps = True
+
     # to be migrated when ready for use
     aten_fx_overlap_preserving_bucketing = False
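
For context, the new flag lives under torch._inductor.config.test_configs and defaults to True. A minimal sketch of toggling it per-compile, mirroring how the test above patches it; the toy function and input shapes here are illustrative and not part of this commit:

import torch

def toy(x):
    # Placeholder workload; any compiled function goes here.
    return torch.matmul(x, x.T)

# Disable the ordering-dep insertion just for this compile, the same way the
# get_patches()/get_bucket_patches() entries added in the test file do.
with torch._inductor.config.patch(
    "test_configs.aten_fx_overlap_insert_overlap_deps", False
):
    compiled = torch.compile(toy)
    out = compiled(torch.randn(8, 8))

The test-suite patches set it to False only because the inserted deps would otherwise break existing FileCheck strings.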

torch/_inductor/fx_passes/overlap_preserving_bucketer.py

Lines changed: 20 additions & 0 deletions
@@ -86,6 +86,10 @@ def bucket_collectives(self) -> None:
         from torch._dynamo.graph_deduplication import _stable_topological_sort

         _stable_topological_sort(self.graph, additional_deps)
+
+        # After topological sort, preserve dependencies using effect tokens
+        self._preserve_dependencies_with_tokens(additional_deps)
+
         self.graph.lint()

     def _find_buckets(

@@ -254,3 +258,19 @@ def _apply_bucket(
             overlap_deps[new_wait].add(info.hiding_node)

         return overlap_deps
+
+    def _preserve_dependencies_with_tokens(
+        self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]]
+    ) -> None:
+        """
+        Preserve dependencies using effect tokens and with_effects higher-order op.
+
+        Uses the standalone token_dependencies utility for consistent behavior
+        across different overlap scheduling approaches.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        if torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            preserve_node_ordering(self.graph, additional_deps)
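
For reference, the mapping handed to preserve_node_ordering keys each fx.Node on the set of nodes it must stay ordered after; it is the same dict[fx.Node, OrderedSet[fx.Node]] shape already fed to _stable_topological_sort in bucket_collectives. A schematic sketch of that call shape on a toy traced graph; the function g, the node names, and the OrderedSet import location are assumptions for illustration, and the final call requires this commit's control_dependencies pass:

import torch
from torch import fx
from torch.utils._ordered_set import OrderedSet  # assumed location of OrderedSet

def g(x):
    a = torch.relu(x)
    b = torch.sigmoid(x)
    return a + b

gm = fx.symbolic_trace(g)
nodes = {n.name: n for n in gm.graph.nodes}

# "sigmoid must stay ordered after relu": same shape of mapping the bucketer
# builds in _apply_bucket and hands to the ordering machinery.
additional_deps = {nodes["sigmoid"]: OrderedSet([nodes["relu"]])}

# With this commit applied (and the test_configs flag left at its default True),
# the bucketer would then call:
# from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering
# preserve_node_ordering(gm.graph, additional_deps)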

torch/_inductor/fx_passes/overlap_scheduling.py

Lines changed: 30 additions & 0 deletions
@@ -378,10 +378,40 @@ def run(self) -> torch.fx.GraphModule:
             self._handle_other(node)

         self._reorder_graph()
+
         if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
             self._bucket_collectives()
+        elif torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            # If not bucketing, add effect tokens to preserve hiding dependencies
+            self._add_effect_tokens_for_overlap()
+
         return self.gm

+    def _add_effect_tokens_for_overlap(self) -> None:
+        """
+        Add effect tokens to preserve hiding dependency relationships when not bucketing.
+
+        This ensures that communication-compute overlap is preserved through effect tokens
+        when overlap preserving bucketing is not enabled.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node
+        additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        for start_node, info in self.collective_info.items():
+            if info.hiding_node and not info.is_exposed:
+                # Compute depends on collective start (compute must wait for collective to start)
+                additional_deps[info.hiding_node].add(start_node)
+                # Wait depends on compute (wait must wait for compute to finish)
+                additional_deps[info.wait_node].add(info.hiding_node)
+
+        # Apply effect tokens to preserve these dependencies
+        if additional_deps:
+            preserve_node_ordering(self.graph, additional_deps)
+
     def _handle_other(self, node: fx.Node) -> None:
         self._schedule(node)
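
Put concretely, for each collective whose latency is hidden (hiding_node set and not exposed), the pass records two ordering edges: the hiding compute after the collective start, and the wait after the hiding compute. A standalone sketch of that bookkeeping, using plain strings in place of fx.Nodes and a made-up HiddenCollective record instead of the real collective_info entries:

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass
class HiddenCollective:
    # Stand-ins for the fx.Nodes tracked per collective in collective_info.
    start_node: str
    wait_node: str
    hiding_node: Optional[str]
    is_exposed: bool

def collect_overlap_deps(collectives):
    deps = defaultdict(set)
    for info in collectives:
        if info.hiding_node and not info.is_exposed:
            # The hiding compute must come after the collective kicks off ...
            deps[info.hiding_node].add(info.start_node)
            # ... and the wait must come after the hiding compute finishes.
            deps[info.wait_node].add(info.hiding_node)
    return deps

# Example: one all-gather hidden behind a matmul.
print(dict(collect_overlap_deps(
    [HiddenCollective("ag1_start", "ag1_wait", "mm1", is_exposed=False)]
)))
# {'mm1': {'ag1_start'}, 'ag1_wait': {'mm1'}}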
