
Commit 35f66b8

eellison authored and pytorchmergebot committed
respect aten planned overlap in inductor (#164569)
Now that we have a hop to add implicit deps, use those deps for comm/compute overlap.

Pull Request resolved: #164569
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164568
1 parent 4a39820 commit 35f66b8
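To make the intent concrete: a collective's start and its wait bracket independent compute, and the inserted deps pin that bracket in place. A minimal sketch (mine, not from the commit; assumes an initialized process group, and mirrors the new test below):

import torch
import torch.distributed._functional_collectives as fc

def f(a, b, ranks):
    ag = fc.all_gather_tensor(a, 0, ranks)  # collective start
    mm = torch.matmul(b, b.T)               # independent compute that can hide the comm
    # The wait is implicit on first use of `ag`. The control deps added by this
    # PR pin the order: all_gather start -> mm -> wait, so Inductor's scheduler
    # cannot sink the start past mm or hoist the wait above it.
    return ag.sum() + mm.sum()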

File tree

6 files changed: +154 -13 lines changed

test/distributed/test_aten_comm_compute_reordering.py
torch/_inductor/config.py
torch/_inductor/fx_passes/overlap_preserving_bucketer.py
torch/_inductor/fx_passes/overlap_scheduling.py
torch/_inductor/lowering.py
torch/_inductor/scheduler.py

test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 84 additions & 1 deletion

@@ -12,7 +12,7 @@
 import torch.distributed._functional_collectives as _functional_collectives
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
 from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     at_least_x_gpu,
@@ -67,6 +67,8 @@ def get_patches():
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # Messes up existing test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
 
 
@@ -358,6 +360,8 @@ def get_bucket_patches(compute_multiplier=1.0):
         "reorder_for_compute_comm_overlap_passes": [],
         "compile_threads": 1,
         "force_disable_caches": True,
+        # messes up test strings
+        "test_configs.aten_fx_overlap_insert_overlap_deps": False,
     }
 
 
@@ -750,6 +754,85 @@ def func(a, b, c, *, ranks):
         correct = func(a, b, c, ranks=ranks)
         self.assertTrue(same(out, correct))
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches(2.0))
+    def test_bucketing_split_for_overlap_blocking_deps_inductor(self):
+        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
+
+        # check that ordering is preserved in inductor
+
+        def func(a, b, c, d, *, ranks):
+            # All 4 all-gathers are independent - COULD be bucketed together
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
+            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
+
+            # First compute - can hide ag1 and ag2
+            e = a * 5  # Use a to avoid fusion
+            mm1 = torch.matmul(e, e.T)
+
+            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
+            # Use first 8x8 elements to match mm1's shape
+            intermediate = ag1[:8, :8] + ag2[:8, :8]
+
+            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
+            mm2 = torch.matmul(mm1 + intermediate, c[:8])
+
+            # Use all results
+            result = (
+                ag1.sum() * 1.1
+                + ag2.sum() * 1.2
+                + ag3.sum() * 1.3
+                + ag4.sum() * 1.4
+                + mm1.sum()
+                + mm2.sum()
+            )
+            return result
+
+        li = []
+        apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
+        with (
+            _dynamo_dist_per_rank_init(
+                self.rank,
+                self.world_size,
+                self.backend(device_type),
+                fake_pg=not at_least_x_gpu(2),
+            ),
+            torch._inductor.config.patch(
+                "test_configs.aten_fx_overlap_insert_overlap_deps", True
+            ),
+            torch._inductor.config.patch(post_grad_custom_post_pass=apply),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
+            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            compiled = torch.compile(func_c)
+            test_out, (code,) = run_and_get_code(compiled, a, b, c, d)
+
+            # Check that right deps are added
+            f = FileCheck()
+            for _ in range(2):
+                f.check("control_deps").check_same("all_gather").check_same(
+                    "subgraph_mm"
+                )
+            f.check("control_deps").check_same("mm").check_same("subgraph_wait")
+            f.run(li[0])
+
+            f = FileCheck()
+            for _ in range(2):
+                f.check_count("all_gather_into_tensor_out.default(", 1, exactly=True)
+                f.check_count("extern_kernels.mm(", 1, exactly=True)
+                f.check_count("wait_tensor.default(", 1, exactly=True)
+            f.run(code)
+
+            correct = func(a, b, c, d, ranks=ranks)
+            self.assertTrue(same(test_out, correct))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions

@@ -2029,6 +2029,9 @@ class test_configs:
     # to be migrated when ready for use
     aten_fx_overlap_scheduling = False
 
+    # insert ordering deps for overlap
+    aten_fx_overlap_insert_overlap_deps = True
+
     # to be migrated when ready for use
     aten_fx_overlap_preserving_bucketing = False
 
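The new test toggles this flag via `config.patch`; a minimal standalone sketch of the same toggle (the function `fn` here is a placeholder of mine):

import torch

def fn(x):
    return x * 2

# Patch the test config for the duration of a compile-and-run.
with torch._inductor.config.patch(
    "test_configs.aten_fx_overlap_insert_overlap_deps", True
):
    out = torch.compile(fn)(torch.ones(4))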

torch/_inductor/fx_passes/overlap_preserving_bucketer.py

Lines changed: 20 additions & 0 deletions

@@ -86,6 +86,10 @@ def bucket_collectives(self) -> None:
         from torch._dynamo.graph_deduplication import _stable_topological_sort
 
         _stable_topological_sort(self.graph, additional_deps)
+
+        # After topological sort, preserve dependencies using effect tokens
+        self._preserve_dependencies_with_tokens(additional_deps)
+
         self.graph.lint()
 
     def _find_buckets(
@@ -254,3 +258,19 @@ def _apply_bucket(
                 overlap_deps[new_wait].add(info.hiding_node)
 
         return overlap_deps
+
+    def _preserve_dependencies_with_tokens(
+        self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]]
+    ) -> None:
+        """
+        Preserve dependencies using effect tokens and with_effects higher-order op.
+
+        Uses the standalone token_dependencies utility for consistent behavior
+        across different overlap scheduling approaches.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        if torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            preserve_node_ordering(self.graph, additional_deps)
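A hedged sketch of the utility's contract on a toy FX graph. This is illustrative only: it assumes `preserve_node_ordering(graph, deps)` inserts control-dependency ops so each key node stays ordered after every node in its value set, that the pass expects `node.meta` populated (hence `make_fx`), and that the node names `add`/`mul` follow `make_fx` default naming:

from collections import defaultdict

import torch
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._ordered_set import OrderedSet
from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering

def f(x, y):
    a = x + 1
    b = y * 2
    return a, b

# make_fx populates node.meta, which graph passes generally expect.
gm = make_fx(f)(torch.ones(4), torch.ones(4))
nodes = {n.name: n for n in gm.graph.nodes}

# Require "b = y * 2" to stay ordered after "a = x + 1", even though no data
# flows between them.
deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
deps[nodes["mul"]].add(nodes["add"])

preserve_node_ordering(gm.graph, deps)
gm.recompile()
print(gm.graph)  # expect a control_deps higher-order op wrapping the mul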

torch/_inductor/fx_passes/overlap_scheduling.py

Lines changed: 30 additions & 0 deletions

@@ -378,10 +378,40 @@ def run(self) -> torch.fx.GraphModule:
             self._handle_other(node)
 
         self._reorder_graph()
+
         if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
             self._bucket_collectives()
+        elif torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps:
+            # If not bucketing, add effect tokens to preserve hiding dependencies
+            self._add_effect_tokens_for_overlap()
+
         return self.gm
 
+    def _add_effect_tokens_for_overlap(self) -> None:
+        """
+        Add effect tokens to preserve hiding dependency relationships when not bucketing.
+
+        This ensures that communication-compute overlap is preserved through effect tokens
+        when overlap preserving bucketing is not enabled.
+        """
+        from torch._inductor.fx_passes.control_dependencies import (
+            preserve_node_ordering,
+        )
+
+        # Collect hiding dependencies: hiding_node -> collective_start, wait -> hiding_node
+        additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+
+        for start_node, info in self.collective_info.items():
+            if info.hiding_node and not info.is_exposed:
+                # Compute depends on collective start (compute must wait for collective to start)
+                additional_deps[info.hiding_node].add(start_node)
+                # Wait depends on compute (wait must wait for compute to finish)
+                additional_deps[info.wait_node].add(info.hiding_node)
+
+        # Apply effect tokens to preserve these dependencies
+        if additional_deps:
+            preserve_node_ordering(self.graph, additional_deps)
+
     def _handle_other(self, node: fx.Node) -> None:
         self._schedule(node)
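To see which edges the loop above records, here is a toy walk-through in plain Python; the `Info` tuple and the node names are illustrative stand-ins of mine for the pass's collective-info bookkeeping:

from collections import defaultdict, namedtuple

Info = namedtuple("Info", ["hiding_node", "wait_node", "is_exposed"])

# One collective hidden by mm1, one exposed (nothing overlaps it).
collective_info = {
    "ag1_start": Info(hiding_node="mm1", wait_node="ag1_wait", is_exposed=False),
    "ag2_start": Info(hiding_node=None, wait_node="ag2_wait", is_exposed=True),
}

deps = defaultdict(set)
for start, info in collective_info.items():
    if info.hiding_node and not info.is_exposed:
        deps[info.hiding_node].add(start)           # compute after collective start
        deps[info.wait_node].add(info.hiding_node)  # wait after hiding compute

# Only the hidden collective contributes edges; the exposed one adds nothing.
assert dict(deps) == {"mm1": {"ag1_start"}, "ag1_wait": {"mm1"}}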

torch/_inductor/lowering.py

Lines changed: 14 additions & 9 deletions

@@ -7255,29 +7255,34 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
 
     output = None
 
+    operation_len = len(V.graph.operations)
     assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
     for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
         if node.op == "placeholder":
+            assert node not in V.graph.env
            V.graph.env[node] = args[i]
             continue
         elif node.op == "output":
             args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
             output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
         else:
+            assert node not in V.graph.env
             V.graph.env[node] = V.graph.run_node(node)
 
     assert output is not None and additional_deps
-    output_list = output if isinstance(output, (list, tuple)) else [output]
 
-    for out in output_list:
-        if not isinstance(out, IRNode):
-            continue
-
-        # need to realize in order to add the dep
-        out.realize()
-        out_name = out.get_name()
+    # some operators, like wait_tensor, just return their input,
+    # so it's more robust to add the dep to the operation itself;
+    # otherwise you can have a cycle:
+    #   a = coll
+    #   b = control_deps(a, mm, ...)
+    #   c = control_deps(b, wait, ...)
+    # if c == a, then you have a cycle.
+    for op in V.graph.operations[operation_len:]:
         for dep_name in dep_names:
-            V.graph.additional_buffer_deps[out_name].add(dep_name)
+            op_name = op.operation_name
+            assert op_name is not None
+            V.graph.additional_buffer_deps[op_name].add(dep_name)
 
     return output
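The cycle the comment warns about can be sketched in plain Python (buffer and op names are illustrative): when `wait_tensor` aliases its input, keying the dep by the output buffer makes the collective's own buffer depend on the mm that must run after it:

# Buffer-keyed deps: control_deps(a, mm, ...) records mm-after-a, and
# control_deps(b, wait, ...) would record a dep on wait's "output" buffer.
# But wait_tensor returns its input, so that output buffer IS "a":
buffer_deps = {
    "mm": {"a"},  # mm must run after buffer a (the collective output) exists
    "a": {"mm"},  # wait's output aliases a -> a must run after mm: cycle!
}

# Operation-keyed deps (what the new lowering records) stay acyclic, because
# the wait *operation* is distinct from the buffer it returns:
op_deps = {
    "mm_op": {"a"},
    "wait_op": {"mm"},
}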

torch/_inductor/scheduler.py

Lines changed: 3 additions & 3 deletions

@@ -2684,9 +2684,9 @@ def add_user(
             )
         add_user(other_name, node, is_weak=True)
 
-        for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
-            add_user(add_dep, node, is_weak=True)
-            node.add_fake_dep(WeakDep(add_dep, node.get_name()))
+        for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
+            add_user(add_dep, node, is_weak=True)
+            node.add_fake_dep(WeakDep(add_dep, node.get_name()))
 
         # add normal non-mutation dependencies
         for read in node.read_writes.reads:
