[ContextParallel] add process-time based Round-Robin load-balance to CP #163617
XilunWu wants to merge 17 commits into gh/XilunWu/172/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163617
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit ebb2dcd with merge base b54e466. This comment was automatically generated by Dr. CI and updates every 15 minutes.
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
```python
    tasks_in_group, _ = torch.sort(tasks_in_group, dim=1)
    return tasks_in_group

def _generate_indices(self, restore: bool = False) -> Tensor:
```
Very pretty :)
We don't have a DtoH sync here, right?

No, all operations are tensor ops and happen on `block_mask`'s device.
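For intuition, here is a minimal, self-contained sketch of a process-time-based zig-zag Round-Robin assignment done purely with tensor ops, so everything stays on the mask's device. The helper name `zigzag_round_robin` and the toy costs are made up for illustration; this is not the PR's actual `_PTRRLoadBalancer.ptrr_scheduling`, only the same idea, ending in the same per-rank sort shown in the diff above.

```python
import torch

def zigzag_round_robin(cost: torch.Tensor, group_size: int) -> torch.Tensor:
    """Deal blocks to `group_size` ranks in zig-zag (snake) order, heaviest
    first, using only tensor ops so nothing needs to leave the device."""
    assert cost.numel() % group_size == 0
    # Block ids ordered from most to least expensive.
    order = torch.argsort(cost, descending=True)
    # One row per dealing round: (num_rounds, group_size).
    rounds = order.view(-1, group_size).clone()
    # Reverse every other round so the rank that got the heaviest block in one
    # round gets the lightest block in the next.
    rounds[1::2] = rounds[1::2].flip(-1)
    # Column r of `rounds` is rank r's block list; transpose and sort each
    # rank's list, as in the `torch.sort(tasks_in_group, dim=1)` line above.
    tasks_in_group = rounds.t().contiguous()
    tasks_in_group, _ = torch.sort(tasks_in_group, dim=1)
    return tasks_in_group

# Toy run: 8 blocks with made-up per-block costs, 4 ranks.
cost = torch.tensor([5, 1, 4, 2, 8, 3, 7, 6])
shards = zigzag_round_robin(cost, group_size=4)   # tensor([[1, 4], [3, 6], [5, 7], [0, 2]])
per_rank_cost = cost[shards].sum(dim=1)           # tensor([9, 9, 9, 9]) -- balanced
```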
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
fegin left a comment:
LGTM, please address the comments before merging the PR.
```python
    Warning:
        For Multi-Head Attention, we require the masks over the head dimension are identical
        (i.e. `self.block_mask` must have shape (B, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)).
```
We should add the check in `__init__()`.
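A rough sketch of what that check could look like (the helper name `_check_mask_heads` is hypothetical; it reads `kv_num_blocks`, whose shape is `(B, H, num_q_blocks)`, and would be called from `__init__()`):

```python
from torch.nn.attention.flex_attention import BlockMask

def _check_mask_heads(block_mask: BlockMask) -> None:
    # kv_num_blocks has shape (B, H, num_q_blocks); the Round-Robin strategy
    # assumes the mask is identical across heads, i.e. H == 1.
    _, H, _ = block_mask.kv_num_blocks.shape
    if H != 1:
        raise ValueError(
            "_PTRRLoadBalancer requires identical masks across heads: "
            f"block_mask must have shape (B, 1, seq_len, seq_len), got H={H}"
        )
```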
```python
non_sparse_kv_num_blocks = (
    kv_num_blocks + full_kv_num_blocks
    if full_kv_num_blocks is not None
    else kv_num_blocks
)
B, H, Q = non_sparse_kv_num_blocks.shape
# requirement: the masking is identical across heads (i.e. H == 1 in BlockMask)
non_sparse_kv_num_blocks = non_sparse_kv_num_blocks.view(-1, Q)  # (B, Q_BLK)

batch_ptrr = torch.vmap(
    functools.partial(
        _PTRRLoadBalancer.ptrr_scheduling,
        group_size=self.world_size,
    )
)
ptrr_indices = batch_ptrr(
    non_sparse_kv_num_blocks
)  # (B, group_size, num_blks_in_group)
ptrr_indices = ptrr_indices.reshape(B, -1)  # (B, num_blocks)

# NOTE: only support the case where the qkv block size are equal
q_blk_size, kv_blk_size = block_mask.BLOCK_SIZE
assert q_blk_size == kv_blk_size, (
    "for now only support q_blk_size == kv_blk_size"
)

indices = torch.arange(
    q_blk_size * ptrr_indices.size(1), device=ptrr_indices.device
).view(-1, q_blk_size)  # (NUM_BLOCKS, BLOCK_SIZE)
indices = indices[ptrr_indices].view(B, -1)  # (B, qkv_size)

if restore:
    indices = torch.vmap(torch.argsort)(indices)
```
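To make the index expansion concrete, here is a tiny self-contained walkthrough with made-up numbers (one batch entry, 4 blocks of size 2, and an arbitrary block permutation standing in for `ptrr_indices`): the block-level Round-Robin permutation is turned into element-level gather indices, and `restore=True` inverts it with a per-batch `argsort`.

```python
import torch

q_blk_size = 2
ptrr_indices = torch.tensor([[2, 0, 3, 1]])                 # (B, num_blocks)
B = ptrr_indices.size(0)

indices = torch.arange(q_blk_size * ptrr_indices.size(1)).view(-1, q_blk_size)
# indices == [[0, 1], [2, 3], [4, 5], [6, 7]]               # (NUM_BLOCKS, BLOCK_SIZE)
shard_idx = indices[ptrr_indices].view(B, -1)               # tensor([[4, 5, 0, 1, 6, 7, 2, 3]])

# With restore=True, the per-batch argsort inverts the permutation.
restore_idx = torch.vmap(torch.argsort)(shard_idx)
x = torch.arange(8).view(1, -1)
assert torch.equal(x[0, shard_idx[0]][restore_idx[0]], x[0])
```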
I'm wondering whether we should put this logic into a separate function. The main reason is that I'm worried about the performance implications and am thinking about whether we should compile the code.

Yes, simply add `@torch.compile(fullgraph=True)` to the function definition. But we'll need to remove the `raise` and `assert` statements, since they would break the graph when hit.
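A minimal sketch of that suggestion, assuming the tail of `_generate_indices` is factored into a standalone helper; the name `_expand_block_indices` and its signature are hypothetical, and the `q_blk_size == kv_blk_size` assert would stay outside the compiled region:

```python
import torch

@torch.compile(fullgraph=True)
def _expand_block_indices(
    ptrr_indices: torch.Tensor, q_blk_size: int, restore: bool
) -> torch.Tensor:
    # Expand the block-level permutation to element-level gather indices.
    B = ptrr_indices.size(0)
    indices = torch.arange(
        q_blk_size * ptrr_indices.size(1), device=ptrr_indices.device
    ).view(-1, q_blk_size)
    indices = indices[ptrr_indices].view(B, -1)
    if restore:
        indices = torch.vmap(torch.argsort)(indices)
    return indices
```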
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…balance to CP" **Summary** The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/ Longest-processing-time-first with extra padding added for collectives. - Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order). - Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`. **Test** `pytest test/distributed/tensor/test_attention.py` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
#163617 removed the if/else statement that checked whether the input buffers have the batch dimension. This PR fixes the issue and also adds a test. In the future, we should explicitly ask users to unsqueeze the batch dimension; that would be a BC change to the existing contract, but implicitly inferring the existence of the batch dimension is not safe.

Pull Request resolved: #165792 Approved by: https://github.com/XilunWu
…CP (pytorch#163617) Pull Request resolved: pytorch#163617 Approved by: https://github.com/fegin
ghstack-source-id: 970f2cc Pull Request resolved: pytorch/pytorch#163617
Stack from ghstack (oldest at bottom):
Summary
The load-balancing problem can be modeled as the [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load balancing, and in this PR we start by adding a Round-Robin solution as an example, along with verification. This can be easily adapted to other solutions such as Shortest-processing-time-first / Longest-processing-time-first, with extra padding added for collectives.
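For comparison, a minimal sketch (a hypothetical helper, not part of this PR) of the greedy Longest-processing-time-first variant mentioned above; unlike Round-Robin it can give different ranks different block counts, which is why extra padding would be needed for the collectives:

```python
import torch

def lpt_scheduling(cost: torch.Tensor, group_size: int) -> list[list[int]]:
    """Greedy Longest-processing-time-first: hand each block (heaviest first)
    to the currently least-loaded rank."""
    loads = torch.zeros(group_size)
    assignment: list[list[int]] = [[] for _ in range(group_size)]
    for blk in torch.argsort(cost, descending=True).tolist():
        rank = int(torch.argmin(loads))
        assignment[rank].append(blk)
        loads[rank] += cost[blk]
    return assignment

# e.g. lpt_scheduling(torch.tensor([9., 3., 2., 1.]), group_size=2)
# -> [[0], [1, 2, 3]]: ranks end up with different block counts,
#    which is where the extra padding for collectives comes in.
```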
- Added a new type of `_LoadBalancer` implementation, `_PTRRLoadBalancer`, which is designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and performs Round-Robin scheduling (unlike traditional Round-Robin, which assigns in circular order, we assign in zig-zag order).
- Make `_context_parallel_buffers` and `context_parallel_unshard` handle a batched load-balance index (previously they could only handle a non-batched load-balance index), like in `create_cp_block_mask`.

Test

`pytest test/distributed/tensor/test_attention.py`

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci