Commit ebb2dcd
committed
Update on "[ContextParallel] add process-time based Round-Robin load-balance to CP"
**Summary**
The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for
implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example
and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/
Longest-processing-time-first with extra padding added for collectives.
- Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for
`flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform
Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order).
- Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance
index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`.
**Test**
`pytest test/distributed/tensor/test_attention.py`
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci
[ghstack-poisoned]File tree
269 files changed
+5589
-1987
lines changed- .ci/manywheel
- .github/ci_commit_pins
- aten/src/ATen
- cuda/tunable
- functorch
- native
- cuda
- mps/operations
- benchmarks/dynamo
- ci_expected_accuracy
- rocm
- c10/util
- docs/source
- test
- distributed
- elastic/multiprocessing
- bin
- pipelining
- tensor
- dynamo_expected_failures
- dynamo
- cpython/3_13
- expect
- export
- functorch
- fx
- inductor
- profiler
- tools/stats
- torch
- _C
- _dynamo
- variables
- _functorch/_aot_autograd
- _inductor
- codegen
- fx_passes
- compiler
- csrc
- acc
- api/include/torch/nn
- functional
- modules/container
- autograd
- cuda
- distributed
- c10d
- rpc
- inductor
- aoti_eager
- aoti_runtime
- jit
- api
- codegen/cuda
- frontend
- ir
- mobile
- passes
- onnx
- quantization
- python
- runtime
- static
- tensorexpr
- operators
- lazy/core
- monitor
- profiler
- orchestration
- python
- standalone
- utils
- cuda
- distributed
- _local_tensor
- pipelining
- tensor
- debug
- experimental
- fx
- experimental
- migrate_gradual_types
- unification/multipledispatch
- passes
- infra
- utils
- lib/libshm
- nativert
- common
- detail
- executor
- graph
- nn
- attention
- testing/_internal
- optests
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
269 files changed
+5589
-1987
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
187 | 187 | | |
188 | 188 | | |
189 | 189 | | |
190 | | - | |
191 | 190 | | |
192 | 191 | | |
193 | 192 | | |
194 | 193 | | |
195 | 194 | | |
196 | 195 | | |
197 | | - | |
198 | 196 | | |
199 | 197 | | |
200 | 198 | | |
201 | 199 | | |
202 | 200 | | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
203 | 206 | | |
204 | 207 | | |
205 | 208 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
| 1 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
| 1 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
256 | 256 | | |
257 | 257 | | |
258 | 258 | | |
| 259 | + | |
259 | 260 | | |
260 | 261 | | |
261 | 262 | | |
| |||
292 | 293 | | |
293 | 294 | | |
294 | 295 | | |
295 | | - | |
| 296 | + | |
296 | 297 | | |
297 | 298 | | |
298 | 299 | | |
299 | 300 | | |
300 | 301 | | |
301 | 302 | | |
302 | | - | |
303 | | - | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
311 | | - | |
312 | | - | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | | - | |
327 | | - | |
328 | | - | |
329 | | - | |
330 | | - | |
331 | | - | |
332 | | - | |
333 | | - | |
334 | | - | |
335 | | - | |
336 | | - | |
337 | | - | |
338 | | - | |
339 | | - | |
340 | | - | |
341 | | - | |
342 | | - | |
343 | | - | |
344 | | - | |
345 | | - | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
346 | 330 | | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
347 | 354 | | |
348 | 355 | | |
349 | 356 | | |
| |||
692 | 699 | | |
693 | 700 | | |
694 | 701 | | |
695 | | - | |
696 | | - | |
697 | | - | |
698 | | - | |
699 | | - | |
700 | | - | |
701 | 702 | | |
702 | 703 | | |
703 | 704 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| |||
150 | 151 | | |
151 | 152 | | |
152 | 153 | | |
| 154 | + | |
153 | 155 | | |
154 | 156 | | |
155 | 157 | | |
| |||
244 | 246 | | |
245 | 247 | | |
246 | 248 | | |
247 | | - | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
248 | 255 | | |
249 | | - | |
250 | 256 | | |
251 | 257 | | |
252 | 258 | | |
253 | 259 | | |
254 | | - | |
255 | | - | |
256 | | - | |
257 | | - | |
258 | | - | |
259 | | - | |
260 | | - | |
261 | | - | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | | - | |
268 | | - | |
269 | | - | |
270 | | - | |
271 | | - | |
272 | 260 | | |
273 | | - | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
274 | 268 | | |
275 | 269 | | |
276 | 270 | | |
| |||
355 | 349 | | |
356 | 350 | | |
357 | 351 | | |
| 352 | + | |
| 353 | + | |
358 | 354 | | |
359 | | - | |
| 355 | + | |
360 | 356 | | |
361 | 357 | | |
362 | 358 | | |
| |||
449 | 445 | | |
450 | 446 | | |
451 | 447 | | |
| 448 | + | |
| 449 | + | |
452 | 450 | | |
453 | | - | |
| 451 | + | |
454 | 452 | | |
455 | 453 | | |
456 | 454 | | |
| |||
546 | 544 | | |
547 | 545 | | |
548 | 546 | | |
| 547 | + | |
| 548 | + | |
549 | 549 | | |
550 | | - | |
| 550 | + | |
551 | 551 | | |
552 | 552 | | |
553 | 553 | | |
| |||
663 | 663 | | |
664 | 664 | | |
665 | 665 | | |
666 | | - | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
667 | 669 | | |
668 | 670 | | |
669 | 671 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
145 | 145 | | |
146 | 146 | | |
147 | 147 | | |
148 | | - | |
| 148 | + | |
149 | 149 | | |
150 | 150 | | |
151 | 151 | | |
| |||
173 | 173 | | |
174 | 174 | | |
175 | 175 | | |
| 176 | + | |
176 | 177 | | |
177 | 178 | | |
178 | | - | |
179 | | - | |
180 | 179 | | |
181 | 180 | | |
182 | 181 | | |
| |||
0 commit comments