11from collections import defaultdict
2+ from dataclasses import dataclass , field
23from typing import Optional
34
45import torch
1112 is_reduce_scatter_tensor as is_reduce_scatter ,
1213 is_wait_tensor ,
1314)
14- from torch ._inductor .fx_passes .overlap_scheduling import CollBucket , CollectiveInfo
15+ from torch ._inductor .fx_passes .overlap_scheduling import CollectiveInfo
1516from torch .utils ._ordered_set import OrderedSet
1617
1718
@dataclass(slots=True)
class CollBucket:
    """Track information about a bucket of collectives."""

    # Original collective start nodes, in the order they were added.
    collectives: list[fx.Node] = field(default_factory=list)
    total_bytes: int = 0  # Running sum of the bucketed collectives' sizes
    min_start_idx: Optional[int] = None  # Minimum index of collective starts
    max_wait_idx: Optional[int] = None  # Maximum index of collective waits

    bucketed_start: Optional[fx.Node] = None  # After bucketing
    bucketed_wait: Optional[fx.Node] = None  # After bucketing

    def add_collective(
        self,
        coll_info: CollectiveInfo,
        node_idx: dict[fx.Node, int],
    ) -> None:
        """
        Add a collective to this bucket and update bucket metadata.

        Performs every update required when a collective joins the bucket:
        appends its start node to ``collectives``, accumulates ``total_bytes``,
        and widens the [min_start_idx, max_wait_idx] window to cover the new
        start/wait pair.
        """
        start = coll_info.start_node

        self.collectives.append(start)
        self.total_bytes += coll_info.size_bytes

        start_idx = node_idx[start]
        wait_idx = node_idx[coll_info.wait_node]

        # First collective seeds the window; later ones can only widen it.
        self.min_start_idx = (
            start_idx
            if self.min_start_idx is None
            else min(self.min_start_idx, start_idx)
        )
        self.max_wait_idx = (
            wait_idx
            if self.max_wait_idx is None
            else max(self.max_wait_idx, wait_idx)
        )
65+
66+
1867def bucket_key (node : torch .fx .Node ) -> Optional [object ]:
1968 if is_all_gather (node ):
2069 return _ag_group_key (node )
@@ -44,20 +93,19 @@ def __init__(
4493 self .scheduled = scheduled
4594 self .max_bucket_memory_gb = max_bucket_memory_gb
4695 self .node_idx = {n : i for i , n in enumerate (scheduled )}
96+ self .aug_graph = AugmentedGraphHelper (self .graph , node_to_idx = self .node_idx )
4797
4898 def bucket_collectives (self ) -> None :
4999 """Main entry point for bucketing collectives."""
50100
51- aug_graph = AugmentedGraphHelper (self .graph )
52-
53101 # Add extra dependencies for hidden collectives
54102 # For each hidden collective, add: compute -> start and wait -> compute
55103 for start_node , info in self .collective_info .items ():
56104 if info .hiding_node and not info .is_exposed :
57105 # Add edge: hiding_compute depends on start (start must come before compute)
58- aug_graph .add_extra_dep (n = info .hiding_node , dep = start_node )
106+ self . aug_graph .add_extra_dep (n = info .hiding_node , dep = start_node )
59107 # Add edge: wait depends on hiding_compute (compute must come before wait)
60- aug_graph .add_extra_dep (n = info .wait_node , dep = info .hiding_node )
108+ self . aug_graph .add_extra_dep (n = info .wait_node , dep = info .hiding_node )
61109
62110 # Group collectives by bucket key (type, group, etc.)
63111 grouped_collectives : dict [object , OrderedSet [fx .Node ]] = defaultdict (OrderedSet )
@@ -68,7 +116,7 @@ def bucket_collectives(self) -> None:
68116
69117 all_buckets : list [CollBucket ] = []
70118 for collective_group in grouped_collectives .values ():
71- buckets = self ._find_buckets (collective_group , aug_graph )
119+ buckets = self ._find_buckets (collective_group )
72120 all_buckets .extend (buckets )
73121
74122 # Collect all extra dependencies to preserve after bucketing
@@ -95,7 +143,6 @@ def bucket_collectives(self) -> None:
95143 def _find_buckets (
96144 self ,
97145 collective_group : OrderedSet [fx .Node ],
98- aug_graph : AugmentedGraphHelper ,
99146 ) -> list [CollBucket ]:
100147 """Find valid buckets within a group of similar collectives."""
101148
@@ -108,13 +155,10 @@ def _find_buckets(
108155 continue
109156
110157 # Initialize bucket with first collective
111- bucket_info = CollBucket (
112- collectives = [start_node ],
113- total_bytes = self .collective_info [start_node ].size_bytes ,
114- )
158+ bucket_info = CollBucket ()
159+ bucket_info .add_collective (self .collective_info [start_node ], self .node_idx )
115160 processed .add (start_node )
116161
117- # TODO - limit within range
118162 for candidate in collective_group :
119163 if candidate in processed :
120164 continue
@@ -123,9 +167,10 @@ def _find_buckets(
123167 if bucket_info .total_bytes + candidate_bytes > max_bucket_bytes :
124168 continue
125169
126- if self ._can_add_to_bucket (bucket_info , candidate , aug_graph ):
127- bucket_info .collectives .append (candidate )
128- bucket_info .total_bytes += candidate_bytes
170+ if self ._can_add_to_bucket (bucket_info , candidate ):
171+ bucket_info .add_collective (
172+ self .collective_info [candidate ], self .node_idx
173+ )
129174 processed .add (candidate )
130175
131176 if len (bucket_info .collectives ) > 1 :
@@ -137,11 +182,30 @@ def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
137182 """Check if there's an ancestor relationship between two nodes."""
138183 return n1 in self .node_ancestors [n2 ] or n2 in self .node_ancestors [n1 ]
139184
185+ def _has_path (
186+ self ,
187+ source : fx .Node ,
188+ source_bounds : tuple [int , int ],
189+ target : fx .Node ,
190+ target_bounds : tuple [int , int ],
191+ ) -> bool :
192+ """Check if there's a path from source to target with bounded search."""
193+
194+ search_range = (
195+ min (source_bounds [0 ], target_bounds [0 ]),
196+ max (source_bounds [1 ], target_bounds [1 ]),
197+ )
198+
199+ return self .aug_graph .has_path (
200+ source ,
201+ target ,
202+ bounded_search_range = search_range ,
203+ )
204+
140205 def _can_add_to_bucket (
141206 self ,
142207 bucket_info : CollBucket ,
143208 candidate : fx .Node ,
144- aug_graph : AugmentedGraphHelper ,
145209 ) -> bool :
146210 """
147211 Check if candidate can be added to bucket without interfering
@@ -174,29 +238,39 @@ def _can_add_to_bucket(
174238 # Check if there's a path between any existing start and candidate start.
175239 # Because the collectives have already been merged, we can just start from one
176240 # of them.
177- # TODO: we have a range of possible idxs of the merged node, and idx of new node.
178- # we should not do path search beyond that range
179241 existing_coll = bucket_info .collectives [0 ]
180- if aug_graph .has_path (existing_coll , candidate ):
242+
243+ # Calculate bounds for path search
244+ candidate_idx = self .node_idx [candidate ]
245+ candidate_wait_idx = self .node_idx [candidate_wait ]
246+
247+ bucket_min_idx = bucket_info .min_start_idx
248+ bucket_max_idx = bucket_info .max_wait_idx
249+ assert bucket_min_idx is not None and bucket_max_idx is not None
250+ existing_bounds = (bucket_min_idx , bucket_max_idx )
251+ candidate_bounds = (candidate_idx , candidate_wait_idx )
252+
253+ if self ._has_path (existing_coll , existing_bounds , candidate , candidate_bounds ):
181254 return False
182- if aug_graph . has_path (candidate , existing_coll ):
255+ if self . _has_path (candidate , candidate_bounds , existing_coll , existing_bounds ):
183256 return False
184257
185258 # Safe to merge starts - do the merge
186- aug_graph .merge_to_set (existing_coll , candidate )
259+ self . aug_graph .merge_to_set (existing_coll , candidate )
187260
188261 # Step 3: Check and merge waits
189262 existing_wait = self .collective_info [existing_coll ].wait_node
190- candidate_wait = candidate_info .wait_node
191- # TODO - as above, limit search by idx
192- if aug_graph .has_path (existing_wait , candidate_wait ) or aug_graph .has_path (
193- candidate_wait , existing_wait
263+
264+ if self ._has_path (
265+ existing_wait , existing_bounds , candidate_wait , candidate_bounds
266+ ) or self ._has_path (
267+ candidate_wait , candidate_bounds , existing_wait , existing_bounds
194268 ):
195269 # Unmerge the start we just merged
196- aug_graph .unmerge_node (candidate )
270+ self . aug_graph .unmerge_node (candidate )
197271 return False
198272
199- aug_graph .merge_to_set (existing_wait , candidate_wait )
273+ self . aug_graph .merge_to_set (existing_wait , candidate_wait )
200274 return True
201275
202276 def _apply_bucket (
0 commit comments