[CP] Replace context_parallel context manager with functional APIs #164500
fegin wants to merge 15 commits into gh/fegin/326/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164500
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV — There is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures — As of commit 4bda542 with merge base d41aa18.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
XilunWu left a comment
Some nits, stamp to unblock us and we can come back later to address them.
tianyu-l left a comment
I somehow feel we don't need _enable_context_parallel_dispatcher as a user-facing API.
cp_q, cp_k, cp_v = _context_parallel_shard(
    mesh, (cp_q, cp_k, cp_v), (seq_dim,) * 3
)
_enable_context_parallel_dispatcher(seq_dim, mesh)
It looks like, for now, this is only for SDPA but not FlexAttention.
Can we put them in sdpa_cp.sdpa_input_fn and sdpa_cp.sdpa_output_fn? It's also safer that way.
I think the problem is that we cannot put _disable_context_parallel_dispatcher in sdpa_output_fn, because we want to wait until the backward pass before disabling it. IIRC, doing something in a backward hook may cause a graph break? I'm not sure.
Discussed offline, keep it for now. We need to think about a better way to integrate DTensor with SDPA
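To make the concern above concrete, here is a minimal sketch (not the code from this PR) of deferring the dispatcher disable to the backward pass via a tensor hook. The `enable_fn`/`disable_fn` parameters stand in for the private `_enable_context_parallel_dispatcher` / `_disable_context_parallel_dispatcher` helpers; only the `(seq_dim, mesh)` signature shown in the reviewed snippet is assumed.

```python
import torch.nn.functional as F


def sdpa_with_cp_dispatcher(q, k, v, seq_dim, mesh, enable_fn, disable_fn):
    # Turn the CP dispatch override on just before the attention call.
    enable_fn(seq_dim, mesh)
    out = F.scaled_dot_product_attention(q, k, v)

    if q.requires_grad:
        # The dispatcher must stay enabled until SDPA's backward has produced
        # the input gradients, so the disable is deferred to a hook that runs
        # once grad(q) is computed. Under torch.compile such a Python hook may
        # introduce a graph break, which is the concern raised in the thread.
        def _disable_after_backward(grad):
            disable_fn()
            return grad

        q.register_hook(_disable_after_backward)
    else:
        disable_fn()
    return out
```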
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[CP] Replace context_parallel context manager with functional APIs (pytorch#164500)

`context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:
1. Wrap the attention op into an `nn.Module`
2. Lift any buffers that are not sequence agnostic to input

We can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`.

Pull Request resolved: pytorch#164500
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542
… (pytorch#163185)

The custom op will fetch the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter. While the logic is the same as all_gather_tensor_autograd, the custom op avoids the Autograd warning that wait_tensor() is registered to autograd. For the next step, we should explore how to interpolate the required communication based on the information from BlockMask.

Pull Request resolved: pytorch#163185
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542, pytorch#164500
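As a rough illustration of the communication pattern described in that commit message (all-gather in the forward pass, reduce-scatter of the gradient in the backward pass), here is a hedged sketch using functional collectives and a plain `autograd.Function`. The class and argument names are invented for illustration; the landed change registers a custom op instead.

```python
import torch
import torch.distributed._functional_collectives as funcol


class AllGatherWithReduceScatterGrad(torch.autograd.Function):
    """All-gather along the sequence dim in forward; reduce-scatter the
    incoming gradient along the same dim in backward."""

    @staticmethod
    def forward(ctx, tensor, gather_dim, group):
        ctx.gather_dim = gather_dim
        ctx.group = group
        out = funcol.all_gather_tensor(tensor, gather_dim=gather_dim, group=group)
        # funcol returns an async tensor; wait_tensor makes the result concrete.
        return funcol.wait_tensor(out)

    @staticmethod
    def backward(ctx, grad_output):
        grad = funcol.reduce_scatter_tensor(
            grad_output, reduceOp="sum", scatter_dim=ctx.gather_dim, group=ctx.group
        )
        return funcol.wait_tensor(grad), None, None
```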
… (pytorch#165039)

No logic change, just polish the docstrings and comments and remove unused variables.

Pull Request resolved: pytorch#165039
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542, pytorch#164500, pytorch#163185
Stack from ghstack (oldest at bottom):
`context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:

1. Wrap the attention op into an `nn.Module`
2. Lift any buffers that are not sequence agnostic to input

we can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`.

cc @H-Huang @awgu @wanchaol @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci
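Putting the pieces from this PR together, a hedged end-to-end sketch of the functional flow might look like the following. The import path and the placement of the disable call are assumptions; the call signatures follow the reviewed snippet above.

```python
import torch.nn.functional as F

# Assumed import location for the private helpers introduced in this PR;
# the real module path may differ.
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _enable_context_parallel_dispatcher,
    _disable_context_parallel_dispatcher,
)


def cp_attention(mesh, q, k, v, seq_dim=2):
    # q/k/v are assumed to be (batch, heads, seq, head_dim), so seq_dim=2.
    # Shard q, k, v along the sequence dimension across the CP mesh.
    cp_q, cp_k, cp_v = _context_parallel_shard(mesh, (q, k, v), (seq_dim,) * 3)

    # Route subsequent SDPA calls through the CP-aware implementation.
    _enable_context_parallel_dispatcher(seq_dim, mesh)
    try:
        out = F.scaled_dot_product_attention(cp_q, cp_k, cp_v)
    finally:
        # Disabling right after the forward is only safe when no backward is
        # needed; the review thread above discusses deferring this to the
        # backward pass instead.
        _disable_context_parallel_dispatcher()
    return out
```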