
Commit 415e641

eellison authored and pytorchmergebot committed
Limit path search within range (#164581)
When checking whether two nodes are dependent, limit the path search to the range bounded by their node indices.

Pull Request resolved: #164581
Approved by: https://github.com/ezyang
ghstack dependencies: #164568, #164569
1 parent 11f5f65 commit 415e641
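
The optimization in brief: the scheduler numbers nodes in topological order, so every dependency edge points from a higher index back to a lower one, and (as long as the extra dependencies also respect schedule order) any dependency path between two nodes stays inside the index window spanned by its endpoints. The search can therefore skip any node whose index falls outside that window. A minimal standalone sketch of the idea, with illustrative names (`deps`, `node_to_idx`) rather than the PR's own code:

# Sketch of bounded backward reachability, assuming nodes are numbered in
# topological order (every dependency has a smaller index than its user).
def has_path_bounded(deps, node_to_idx, source, target, bounds):
    min_idx, max_idx = bounds
    stack, seen = [target], {target}
    while stack:
        current = stack.pop()
        for dep in deps[current]:
            # A path from source to target cannot leave the index window,
            # so anything outside it can be skipped outright.
            if not (min_idx <= node_to_idx[dep] <= max_idx):
                continue
            if dep == source:
                return True
            if dep not in seen:
                seen.add(dep)
                stack.append(dep)
    return False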

File tree

4 files changed: +155 −49

- test/inductor/test_augmented_graph_helper.py
- torch/_inductor/augmented_graph_helper.py
- torch/_inductor/fx_passes/overlap_preserving_bucketer.py
- torch/_inductor/fx_passes/overlap_scheduling.py


test/inductor/test_augmented_graph_helper.py

Lines changed: 29 additions & 0 deletions
@@ -339,6 +339,35 @@ def test_multiple_merge_unmerge(self):
         self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]})
         self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1)
 
+    def test_has_path_with_bounded_search(self):
+        """Test that bounded search correctly respects search range bounds."""
+        # Create a simple linear chain: x -> A -> B -> C -> D
+        graph = fx.Graph()
+        x = graph.placeholder("x")
+        a = graph.call_function(torch.neg, args=(x,), name="A")
+        b = graph.call_function(torch.abs, args=(a,), name="B")
+        c = graph.call_function(torch.relu, args=(b,), name="C")
+        d = graph.call_function(torch.sigmoid, args=(c,), name="D")
+        graph.output(d)
+
+        node_to_idx = {node: idx for idx, node in enumerate(graph.nodes)}
+        tracker = AugmentedGraphHelper(graph, node_to_idx=node_to_idx)
+
+        # Path exists from A to D: A -> B -> C -> D
+        self.assertTrue(tracker.has_path(a, d))
+
+        # Test with correct bounds: include all nodes in the path
+        a_idx = node_to_idx[a]
+        d_idx = node_to_idx[d]
+        # Bounds that include the full path should find it
+        self.assertTrue(tracker.has_path(a, d, bounded_search_range=(a_idx, d_idx)))
+
+        # Test with incorrect bounds: exclude critical intermediate nodes
+        c_idx = node_to_idx[c]
+        # Bounds that exclude A and B (only allowing C and D) should NOT find the path
+        # because the search can't reach back to A
+        self.assertFalse(tracker.has_path(a, d, bounded_search_range=(c_idx, d_idx)))
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
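
Worth noting about the last assertion: `has_path` walks backward from `target` via `get_merged_deps`, so with bounds `(c_idx, d_idx)` the search starting at D can reach C but must skip B (whose index lies below `c_idx`), and therefore never reaches A.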

torch/_inductor/augmented_graph_helper.py

Lines changed: 25 additions & 12 deletions
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from typing import Optional
 
 import torch
 import torch.fx as fx
@@ -9,18 +10,20 @@ class AugmentedGraphHelper:
     """
     Graph helper that augments the original graph with additional
     dependencies and uses, plus tracks node equivalences for coalescing.
-
     TODO: if this becomes too large of compile time, consider binding
     graphcycles.cc
     """
 
-    def __init__(self, graph: fx.Graph):
+    def __init__(
+        self, graph: fx.Graph, node_to_idx: Optional[dict[fx.Node, int]] = None
+    ):
         # Each node starts in its own singleton set
         self.graph = graph
         self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes}
-
         # Extra dependencies: node depends on dep (dep must come before node)
         self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
+        # Optional node to index mapping for bounded searches
+        self.node_to_idx = node_to_idx
 
     def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None:
         """Add extra dependency: node depends on dep."""
@@ -33,25 +36,20 @@ def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None:
         existing_set = self.merge_sets[existing_node]
         new_set = self.merge_sets[new_node]
         assert len(new_set) == 1
-
         # Add all nodes from new_set to existing_set
         existing_set.update(new_set)
-
         # Update all nodes from new_set to point to existing_set
         for node in new_set:
             self.merge_sets[node] = existing_set
 
     def unmerge_node(self, node: fx.Node) -> None:
         """Remove a node from its merge set, making it singleton."""
         old_set = self.merge_sets[node]
-
         # If already singleton, nothing to do
         if len(old_set) == 1:
             return
-
         # Remove from old set
         old_set.remove(node)
-
         # Make node singleton
         self.merge_sets[node] = OrderedSet([node])
 
@@ -63,22 +61,29 @@ def get_merged_deps(self, node: fx.Node) -> OrderedSet[fx.Node]:
         2. Extra deps of node and its merge equivalents
         """
         deps: OrderedSet[fx.Node] = OrderedSet()
-
         # For each node in the merge set
         for merged_node in self.merge_sets[node]:
             # Add direct dependencies from all_input_nodes
             deps.update(merged_node.all_input_nodes)
             # Add extra dependencies
             deps.update(self.extra_deps[merged_node])
-
         return deps
 
     def has_cycle(self) -> bool:
         merged_deps = {n: self.get_merged_deps(n) for n in self.graph.nodes}
         return torch._dynamo.graph_deduplication._has_cycle(self.graph, merged_deps)
 
-    def has_path(self, source: fx.Node, target: fx.Node) -> bool:
-        """Check if there's a path from source to target."""
+    def has_path(
+        self,
+        source: fx.Node,
+        target: fx.Node,
+        bounded_search_range: Optional[tuple[int, int]] = None,
+    ) -> bool:
+        """
+        Check if there's a path from source to target.
+
+        If bounds are provided, only search nodes whose idx falls within the range.
+        """
        # we should not be checking path from node to itself
        assert self.merge_sets[source] is not self.merge_sets[target]
@@ -92,6 +97,14 @@ def has_path(self, source: fx.Node, target: fx.Node) -> bool:
 
            # Get all dependencies
            for dep in self.get_merged_deps(current):
+                # If using bounds, skip nodes outside the range
+                if bounded_search_range is not None and self.node_to_idx is not None:
+                    min_idx, max_idx = bounded_search_range
+                    dep_idx = self.node_to_idx.get(dep)
+                    assert dep_idx is not None
+                    if dep_idx < min_idx or dep_idx > max_idx:
+                        continue
+
                # Check if we reached source or its equivalent
                if dep in self.merge_sets[source]:
                    return True
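
One caveat the test above makes explicit: the bounds act as a hard filter, not merely an optimization, so if the supplied range fails to cover some node on the only path, `has_path` returns False. Callers are responsible for passing a window wide enough to cover every possible path; the bucketer below guarantees this by taking the union of the two endpoints' start/wait index ranges over the topologically ordered schedule.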

torch/_inductor/fx_passes/overlap_preserving_bucketer.py

Lines changed: 101 additions & 27 deletions
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from dataclasses import dataclass, field
 from typing import Optional
 
 import torch
@@ -11,10 +12,58 @@
     is_reduce_scatter_tensor as is_reduce_scatter,
     is_wait_tensor,
 )
-from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveInfo
+from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
 from torch.utils._ordered_set import OrderedSet
 
 
+@dataclass(slots=True)
+class CollBucket:
+    """Track information about a bucket of collectives."""
+
+    collectives: list[fx.Node] = field(
+        default_factory=list
+    )  # Original collective starts
+    total_bytes: int = 0
+    min_start_idx: Optional[int] = None  # Minimum index of collective starts
+    max_wait_idx: Optional[int] = None  # Maximum index of collective waits
+
+    bucketed_start: Optional[fx.Node] = None  # After bucketing
+    bucketed_wait: Optional[fx.Node] = None  # After bucketing
+
+    def add_collective(
+        self,
+        coll_info: CollectiveInfo,
+        node_idx: dict[fx.Node, int],
+    ) -> None:
+        """
+        Add a collective to this bucket and update bucket metadata.
+
+        This handles all updates needed when adding a collective:
+        - Appends to collectives list
+        - Updates total bytes
+        - Updates min_start_idx and max_wait_idx
+        """
+        collective = coll_info.start_node
+
+        # Add to bucket
+        self.collectives.append(collective)
+        self.total_bytes += coll_info.size_bytes
+
+        # Update min start index
+        start_idx = node_idx[collective]
+        if self.min_start_idx is None:
+            self.min_start_idx = start_idx
+        else:
+            self.min_start_idx = min(self.min_start_idx, start_idx)
+
+        # Update max wait index
+        wait_idx = node_idx[coll_info.wait_node]
+        if self.max_wait_idx is None:
+            self.max_wait_idx = wait_idx
+        else:
+            self.max_wait_idx = max(self.max_wait_idx, wait_idx)
+
+
 def bucket_key(node: torch.fx.Node) -> Optional[object]:
     if is_all_gather(node):
         return _ag_group_key(node)
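
To get a feel for the new bookkeeping, a small sketch; `FakeCollInfo` is a stand-in defined only for this example (the real `CollectiveInfo` lives in overlap_scheduling.py and carries more fields), and the node names and indices are invented:

from dataclasses import dataclass

# Stand-in carrying only the fields add_collective reads.
@dataclass
class FakeCollInfo:
    start_node: object
    wait_node: object
    size_bytes: int

node_idx = {"ag0": 0, "w0": 4, "ag1": 2, "w1": 7}
bucket = CollBucket()
bucket.add_collective(FakeCollInfo("ag0", "w0", 1024), node_idx)
bucket.add_collective(FakeCollInfo("ag1", "w1", 2048), node_idx)
# The bucket's index window now spans both collectives' start/wait ranges,
# which is exactly what the bounded path searches consume.
assert (bucket.min_start_idx, bucket.max_wait_idx) == (0, 7)
assert bucket.total_bytes == 3072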
@@ -44,20 +93,19 @@ def __init__(
         self.scheduled = scheduled
         self.max_bucket_memory_gb = max_bucket_memory_gb
         self.node_idx = {n: i for i, n in enumerate(scheduled)}
+        self.aug_graph = AugmentedGraphHelper(self.graph, node_to_idx=self.node_idx)
 
     def bucket_collectives(self) -> None:
         """Main entry point for bucketing collectives."""
 
-        aug_graph = AugmentedGraphHelper(self.graph)
-
         # Add extra dependencies for hidden collectives
         # For each hidden collective, add: compute -> start and wait -> compute
         for start_node, info in self.collective_info.items():
             if info.hiding_node and not info.is_exposed:
                 # Add edge: hiding_compute depends on start (start must come before compute)
-                aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
+                self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
                 # Add edge: wait depends on hiding_compute (compute must come before wait)
-                aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
+                self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
 
         # Group collectives by bucket key (type, group, etc.)
         grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
@@ -68,7 +116,7 @@ def bucket_collectives(self) -> None:
 
         all_buckets: list[CollBucket] = []
         for collective_group in grouped_collectives.values():
-            buckets = self._find_buckets(collective_group, aug_graph)
+            buckets = self._find_buckets(collective_group)
             all_buckets.extend(buckets)
 
         # Collect all extra dependencies to preserve after bucketing
@@ -95,7 +143,6 @@ def bucket_collectives(self) -> None:
     def _find_buckets(
         self,
         collective_group: OrderedSet[fx.Node],
-        aug_graph: AugmentedGraphHelper,
     ) -> list[CollBucket]:
         """Find valid buckets within a group of similar collectives."""
 
@@ -108,13 +155,10 @@ def _find_buckets(
                 continue
 
             # Initialize bucket with first collective
-            bucket_info = CollBucket(
-                collectives=[start_node],
-                total_bytes=self.collective_info[start_node].size_bytes,
-            )
+            bucket_info = CollBucket()
+            bucket_info.add_collective(self.collective_info[start_node], self.node_idx)
             processed.add(start_node)
 
-            # TODO - limit within range
             for candidate in collective_group:
                 if candidate in processed:
                     continue
@@ -123,9 +167,10 @@ def _find_buckets(
                 if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
                     continue
 
-                if self._can_add_to_bucket(bucket_info, candidate, aug_graph):
-                    bucket_info.collectives.append(candidate)
-                    bucket_info.total_bytes += candidate_bytes
+                if self._can_add_to_bucket(bucket_info, candidate):
+                    bucket_info.add_collective(
+                        self.collective_info[candidate], self.node_idx
+                    )
                     processed.add(candidate)
 
             if len(bucket_info.collectives) > 1:
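Routing every insertion through `add_collective` (instead of appending and summing inline, as before) keeps `min_start_idx` and `max_wait_idx` up to date, which the bounded path queries below rely on; it also retires the old `# TODO - limit within range`.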
@@ -137,11 +182,30 @@ def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
         """Check if there's an ancestor relationship between two nodes."""
         return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1]
 
+    def _has_path(
+        self,
+        source: fx.Node,
+        source_bounds: tuple[int, int],
+        target: fx.Node,
+        target_bounds: tuple[int, int],
+    ) -> bool:
+        """Check if there's a path from source to target with bounded search."""
+
+        search_range = (
+            min(source_bounds[0], target_bounds[0]),
+            max(source_bounds[1], target_bounds[1]),
+        )
+
+        return self.aug_graph.has_path(
+            source,
+            target,
+            bounded_search_range=search_range,
+        )
+
     def _can_add_to_bucket(
         self,
         bucket_info: CollBucket,
         candidate: fx.Node,
-        aug_graph: AugmentedGraphHelper,
     ) -> bool:
         """
         Check if candidate can be added to bucket without interfering
@@ -174,29 +238,39 @@ def _can_add_to_bucket(
         # Check if there's a path between any existing start and candidate start.
         # Because the collectives have already been merged, we can just start from one
         # of them.
-        # TODO: we have a range of possible idxs of the merged node, and idx of new node.
-        # we should not do path search beyond that range
         existing_coll = bucket_info.collectives[0]
-        if aug_graph.has_path(existing_coll, candidate):
+
+        # Calculate bounds for path search
+        candidate_idx = self.node_idx[candidate]
+        candidate_wait_idx = self.node_idx[candidate_wait]
+
+        bucket_min_idx = bucket_info.min_start_idx
+        bucket_max_idx = bucket_info.max_wait_idx
+        assert bucket_min_idx is not None and bucket_max_idx is not None
+        existing_bounds = (bucket_min_idx, bucket_max_idx)
+        candidate_bounds = (candidate_idx, candidate_wait_idx)
+
+        if self._has_path(existing_coll, existing_bounds, candidate, candidate_bounds):
             return False
-        if aug_graph.has_path(candidate, existing_coll):
+        if self._has_path(candidate, candidate_bounds, existing_coll, existing_bounds):
             return False
 
         # Safe to merge starts - do the merge
-        aug_graph.merge_to_set(existing_coll, candidate)
+        self.aug_graph.merge_to_set(existing_coll, candidate)
 
         # Step 3: Check and merge waits
         existing_wait = self.collective_info[existing_coll].wait_node
-        candidate_wait = candidate_info.wait_node
-        # TODO - as above, limit search by idx
-        if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
-            candidate_wait, existing_wait
+
+        if self._has_path(
+            existing_wait, existing_bounds, candidate_wait, candidate_bounds
+        ) or self._has_path(
+            candidate_wait, candidate_bounds, existing_wait, existing_bounds
         ):
             # Unmerge the start we just merged
-            aug_graph.unmerge_node(candidate)
+            self.aug_graph.unmerge_node(candidate)
             return False
 
-        aug_graph.merge_to_set(existing_wait, candidate_wait)
+        self.aug_graph.merge_to_set(existing_wait, candidate_wait)
         return True
 
     def _apply_bucket(
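
Note the design choice in `_has_path`: the search window is the union of the two intervals, since a dependency chain between the two collectives may pass through nodes anywhere in either span. With, say, `existing_bounds = (10, 50)` and `candidate_bounds = (30, 70)` (illustrative numbers), the bounded search covers indices 10 through 70.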

torch/_inductor/fx_passes/overlap_scheduling.py

Lines changed: 0 additions & 10 deletions
@@ -157,16 +157,6 @@ def is_exposed(self) -> bool:
         return self.exposed_time_ms != 0
 
 
-@dataclass
-class CollBucket:
-    """Track information about a bucket of collectives."""
-
-    collectives: list[fx.Node]  # Original collective starts
-    bucketed_start: Optional[fx.Node] = None  # After bucketing
-    bucketed_wait: Optional[fx.Node] = None  # After bucketing
-    total_bytes: int = 0
-
-
 class OverlapScheduler:
     """
     Scheduler that reorders operations to maximize compute-collective overlap.
