Skip to content

Commit 3ee7d6c

Browse files
committed
Update on "respect aten planned overlap in inductor"
Now that we have a HOP to add implicit deps, use those deps for comm/compute overlap. cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
2 parents f469c7c + f80a939 commit 3ee7d6c

File tree

3 files changed

+23
-24
lines changed

3 files changed

+23
-24
lines changed

test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -817,23 +817,17 @@ def func(a, b, c, d, *, ranks):
817817
# Check that right deps are added
818818
f = FileCheck()
819819
for _ in range(2):
820-
f.check("control_deps_op").check_same("all_gather").check_same(
820+
f.check("control_deps").check_same("all_gather").check_same(
821821
"subgraph_mm"
822822
)
823-
f.check("control_deps_op").check_same("mm").check_same("subgraph_wait")
823+
f.check("control_deps").check_same("mm").check_same("subgraph_wait")
824824
f.run(li[0])
825825

826826
f = FileCheck()
827-
f.check("def call").check(
828-
"torch.ops._c10d_functional.all_gather_into_tensor"
829-
)
830-
f.check_count(".mm(", 1, exactly=True)
831-
f.check_count(".wait(", 1, exactly=True)
832-
f.check_count(
833-
"torch.ops._c10d_functional.all_gather_into_tensor_", 1, exactly=True
834-
)
835-
f.check_count(".mm(", 1, exactly=True)
836-
f.check_count(".wait(", 1, exactly=True)
827+
for _ in range(2):
828+
f.check_count("all_gather_into_tensor_out.default(", 1, exactly=True)
829+
f.check_count("extern_kernels.mm(", 1, exactly=True)
830+
f.check_count("wait_tensor.default(", 1, exactly=True)
837831
f.run(code)
838832

839833
correct = func(a, b, c, d, ranks=ranks)

torch/_inductor/lowering.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7255,29 +7255,34 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
72557255

72567256
output = None
72577257

7258+
operation_len = len(V.graph.operations)
72587259
assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
72597260
for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
72607261
if node.op == "placeholder":
7262+
assert node not in V.graph.env
72617263
V.graph.env[node] = args[i]
72627264
continue
72637265
elif node.op == "output":
72647266
args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
72657267
output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
72667268
else:
7269+
assert node not in V.graph.env
72677270
V.graph.env[node] = V.graph.run_node(node)
72687271

72697272
assert output is not None and additional_deps
7270-
output_list = output if isinstance(output, (list, tuple)) else [output]
72717273

7272-
for out in output_list:
7273-
if not isinstance(out, IRNode):
7274-
continue
7275-
7276-
# need to realize in order to add the dep
7277-
out.realize()
7278-
out_name = out.get_name()
7274+
# some operators, like wait_tensor, just return their input,
7275+
# so it's more robust to add the dep to the operation itself,
7276+
# otherwise you can have a cycle of
7277+
# a = coll
7278+
# b = control_deps(a, mm, ...)
7279+
# c = control_deps(b, wait, ...)
7280+
# if c == a, then you have a cycle.
7281+
for op in V.graph.operations[operation_len:]:
72797282
for dep_name in dep_names:
7280-
V.graph.additional_buffer_deps[out_name].add(dep_name)
7283+
op_name = op.operation_name
7284+
assert op_name is not None
7285+
V.graph.additional_buffer_deps[op_name].add(dep_name)
72817286

72827287
return output
72837288

torch/_inductor/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,9 +2681,9 @@ def add_user(
26812681
)
26822682
add_user(other_name, node, is_weak=True)
26832683

2684-
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
2685-
add_user(add_dep, node, is_weak=True)
2686-
node.add_fake_dep(WeakDep(add_dep, node.get_name()))
2684+
for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
2685+
add_user(add_dep, node, is_weak=True)
2686+
node.add_fake_dep(WeakDep(add_dep, node.get_name()))
26872687

26882688
# add normal non-mutation dependencies
26892689
for read in node.read_writes.reads:

0 commit comments

Comments
 (0)