[ContextParallel] add process-time based Round-Robin load-balance to CP #163617

Closed
XilunWu wants to merge 17 commits into gh/XilunWu/172/base from gh/XilunWu/172/head

Conversation

@XilunWu (Contributor) commented Sep 23, 2025

Stack from ghstack (oldest at bottom):

Summary
The load-balancing problem can be modeled as an identical-machines scheduling problem. We already provide an easy-to-extend interface for implementing load balancing in #161062, and in this PR we start by adding a Round-Robin solution as an example, together with verification. This can be easily adapted to other strategies such as Shortest-Processing-Time-first or Longest-Processing-Time-first, with extra padding added for collectives.

  • Added a new `_LoadBalancer` implementation, `_PTRRLoadBalancer`, designed for
    `flex_attention()`. This load-balancing strategy analyzes the `BlockMask` sparsity info and performs
    Round-Robin assignment; unlike traditional Round-Robin, which proceeds in circular order, we assign
    in zig-zag order (see the sketch after this list).
  • Made `_context_parallel_buffers` and `context_parallel_unshard` handle a batched load-balance
    index (previously they could only handle a non-batched load-balance index), as in `create_cp_block_mask`.
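
A minimal sketch of the zig-zag Round-Robin idea, assuming per-task costs are already known; this is an illustration of the scheduling order, not the PR's exact implementation:

    import torch

    def zigzag_round_robin(costs: torch.Tensor, group_size: int) -> torch.Tensor:
        """Assign tasks to `group_size` workers; returns (group_size, tasks_per_worker) indices."""
        assert costs.numel() % group_size == 0
        # Sort tasks heaviest-first, then deal them out one "round" at a time,
        # reversing direction every other round so each worker receives a mix
        # of heavy and light tasks.
        order = torch.argsort(costs, descending=True)
        rounds = order.view(-1, group_size)
        rounds[1::2] = rounds[1::2].flip(dims=[1])
        return rounds.t().contiguous()

    costs = torch.tensor([9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0])
    print(zigzag_round_robin(costs, group_size=2))
    # worker 0 gets tasks with costs [9, 5, 4, 1] (total 19),
    # worker 1 gets tasks with costs [7, 6, 3, 2] (total 18)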

Test
pytest test/distributed/tensor/test_attention.py

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

@pytorch-bot bot commented Sep 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163617

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ebb2dcd with merge base b54e466:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Sep 23, 2025
XilunWu added a commit that referenced this pull request Sep 23, 2025
@XilunWu XilunWu marked this pull request as draft September 23, 2025 06:58
XilunWu added further commits that referenced this pull request Sep 23 to Sep 24, 2025 (repeated ghstack updates), including:

todo: fix unshard problem ; add batch values to tests

ghstack-source-id: 2bb794e
Pull Request resolved: #163617

@XilunWu XilunWu requested a review from fegin October 7, 2025 08:07
@XilunWu XilunWu requested a review from drisspg October 7, 2025 08:07
@XilunWu XilunWu marked this pull request as ready for review October 7, 2025 17:40
        tasks_in_group, _ = torch.sort(tasks_in_group, dim=1)
        return tasks_in_group

    def _generate_indices(self, restore: bool = False) -> Tensor:
Contributor:

Very pretty :)

We don't have a DtoH sync here, right?

Contributor Author:

No, all operations are tensor ops and happen on `block_mask`'s device.
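
For context, a device-to-host sync only happens when a value must materialize on the CPU; sorting and indexing stay on-device. A small illustration (assumes a CUDA device is available):

    import torch

    x = torch.randn(8, device="cuda")
    order = torch.argsort(x)   # runs on the GPU, no DtoH sync
    y = x[order]               # advanced indexing, also on-device
    total = x.sum().item()     # .item() is what would force a DtoH sync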

@XilunWu XilunWu requested a review from drisspg October 7, 2025 22:48
@XilunWu XilunWu mentioned this pull request Oct 11, 2025
@fegin fegin (Contributor) left a comment:

LGTM, please address the comments before merging the PR.


Warning:
For Multi-Head Attention, we require the masks over the head dimension to be identical
(i.e. `self.block_mask` must have shape (B, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)).
Contributor:

We should add the check in `__init__()`.
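
A minimal sketch of what that check could look like (a hypothetical `__init__`, assuming the class stores the mask as `self.block_mask` and that `BlockMask.shape` is (B, H, seq_len, seq_len) as in the warning above; not the PR's final code):

    def __init__(self, block_mask, world_size: int):
        # Identical masks across heads means the head dimension must be 1.
        B, H, *_ = block_mask.shape
        if H != 1:
            raise ValueError(
                f"_PTRRLoadBalancer requires identical masks across heads "
                f"(head dim must be 1 in BlockMask), got H={H}"
            )
        self.block_mask = block_mask
        self.world_size = world_size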

Comment on lines +442 to +474
    non_sparse_kv_num_blocks = (
        kv_num_blocks + full_kv_num_blocks
        if full_kv_num_blocks is not None
        else kv_num_blocks
    )
    B, H, Q = non_sparse_kv_num_blocks.shape
    # requirement: the masking is identical across heads (i.e. H == 1 in BlockMask)
    non_sparse_kv_num_blocks = non_sparse_kv_num_blocks.view(-1, Q)  # (B, Q_BLK)

    batch_ptrr = torch.vmap(
        functools.partial(
            _PTRRLoadBalancer.ptrr_scheduling,
            group_size=self.world_size,
        )
    )
    ptrr_indices = batch_ptrr(
        non_sparse_kv_num_blocks
    )  # (B, group_size, num_blks_in_group)
    ptrr_indices = ptrr_indices.reshape(B, -1)  # (B, num_blocks)

    # NOTE: only support the case where the q and kv block sizes are equal
    q_blk_size, kv_blk_size = block_mask.BLOCK_SIZE
    assert q_blk_size == kv_blk_size, (
        "for now only support q_blk_size == kv_blk_size"
    )

    indices = torch.arange(
        q_blk_size * ptrr_indices.size(1), device=ptrr_indices.device
    ).view(-1, q_blk_size)  # (NUM_BLOCKS, BLOCK_SIZE)
    indices = indices[ptrr_indices].view(B, -1)  # (B, qkv_size)

    if restore:
        # argsort of a permutation yields its inverse, restoring original order
        indices = torch.vmap(torch.argsort)(indices)
Contributor:

I'm wondering whether we should put this logic in a separate function. The main reason is that I'm worried about the performance implications and am thinking about whether we should compile the code.

Contributor Author:

Yes, simply add `@torch.compile(fullgraph=True)` to the function definition. But we'll need to remove `raise` and `assert` statements, since they would break the graph when hit.
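
A minimal sketch of that suggestion, with the indexing logic factored into a free function (the name `_compute_ptrr_indices` and its signature are assumptions for illustration, not the PR's code):

    import torch

    @torch.compile(fullgraph=True)
    def _compute_ptrr_indices(
        ptrr_indices: torch.Tensor, q_blk_size: int, restore: bool
    ) -> torch.Tensor:
        # Expand block-level indices into element-level indices; no raise or
        # assert inside, so fullgraph compilation is not broken.
        B = ptrr_indices.size(0)
        indices = torch.arange(
            q_blk_size * ptrr_indices.size(1), device=ptrr_indices.device
        ).view(-1, q_blk_size)
        indices = indices[ptrr_indices].view(B, -1)
        if restore:
            # batched argsort inverts each row's permutation (equivalent to
            # torch.vmap(torch.argsort) here)
            indices = indices.argsort(dim=-1)
        return indices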

@XilunWu (Contributor Author) commented Oct 15, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2025
@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

fegin added a commit that referenced this pull request Oct 17, 2025
#163617 removes the if/else statement to check if the input buffers have the batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. That would break BC of the existing contract, but implicitly inferring the presence of the batch dimension is not safe.
ghstack-source-id: e441157
Pull-Request: #165792
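
For illustration, explicitly providing the batch dimension on the caller side would look roughly like this (hypothetical shapes, not the actual buffer contract):

    import torch

    buf = torch.randn(4096, 8, 128)  # (seq_len, heads, dim), no batch dim
    buf = buf.unsqueeze(0)           # caller makes the batch dim explicit: (1, seq_len, heads, dim)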
fegin added a commit that referenced this pull request Oct 18, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 18, 2025

Pull Request resolved: #165792
Approved by: https://github.com/XilunWu
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Pull Request resolved: pytorch#163617
Approved by: https://github.com/fegin
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025 (pytorch#165792)
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025 (pytorch#165792)
@github-actions github-actions bot deleted the gh/XilunWu/172/head branch November 16, 2025 02:21
Khanaksahu pushed a commit to Khanaksahu/pytorch-fork that referenced this pull request Nov 17, 2025

Labels

ciflow/inductor, ciflow/trunk, Merged, module: context parallel, oncall: distributed, release notes: context parallel

Projects

None yet


4 participants