
[CP] Replace context_parallel context manager with functional APIs #164500

Closed

fegin wants to merge 15 commits into gh/fegin/326/base from gh/fegin/326/head

Conversation

@fegin
Contributor

@fegin fegin commented Oct 2, 2025

Stack from ghstack (oldest at bottom):

context_parallel() being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:

  1. Wrap the attention op into an nn.Module
  2. Lift any buffers that are not sequence-agnostic to the inputs

we can replace context_parallel() with two functional APIs: _context_parallel_shard and _enable_context_parallel_dispatcher.
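
For illustration, a minimal usage sketch of the two proposed functional APIs, based on the snippet quoted in the review below. The import path, the _disable_context_parallel_dispatcher call, and the exact signatures are assumptions; these are private, experimental APIs and may change.

import torch.nn.functional as F
# Assumed import location; not confirmed in this thread.
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _enable_context_parallel_dispatcher,
    _disable_context_parallel_dispatcher,
)

# `mesh` is the CP device mesh, `q`/`k`/`v` are the full attention inputs,
# and `seq_dim` is their sequence dimension (all provided by the caller).
cp_q, cp_k, cp_v = _context_parallel_shard(mesh, (q, k, v), (seq_dim,) * 3)
_enable_context_parallel_dispatcher(seq_dim, mesh)

out = F.scaled_dot_product_attention(cp_q, cp_k, cp_v)
out.sum().backward()

# Per the discussion below, disabling is deferred until after backward.
_disable_context_parallel_dispatcher()  # signature assumed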

cc @H-Huang @awgu @wanchaol @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

@pytorch-bot

pytorch-bot bot commented Oct 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164500

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 4bda542 with merge base d41aa18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fegin fegin requested a review from XilunWu October 3, 2025 07:05
@fegin fegin changed the title from "[CP] Implement _context_parallel_shard function to replace context_parallel context manager" to "[CP] Replace context_parallel context manager with functional APIs" Oct 3, 2025
fegin added 3 commits October 3, 2025 00:17
Contributor

@XilunWu XilunWu left a comment


Some nits; stamping to unblock us, and we can come back later to address them.

Contributor

@tianyu-l tianyu-l left a comment


I somehow feel we don't need _enable_context_parallel_dispatcher as a user-facing API.

cp_q, cp_k, cp_v = _context_parallel_shard(
    mesh, (cp_q, cp_k, cp_v), (seq_dim,) * 3
)
_enable_context_parallel_dispatcher(seq_dim, mesh)
Contributor


It looks like, for now, this is only for SDPA but not FlexAttention.

Can we put them in sdpa_cp.sdpa_input_fn and sdpa_cp.sdpa_output_fn? It's also safer that way.

Contributor Author


I think the problem is that we cannot put _disable_context_parallel_dispatcher in sdpa_output_fn, because we want to wait until the backward pass before disabling it. IIRC, if we do something in a backward hook, it may cause a graph break? I'm not sure.
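
For context, a hypothetical sketch of the backward-hook approach being discussed here, i.e. deferring the dispatcher disable until autograd reaches the attention output. The disable_fn callback is a stand-in for a _disable_context_parallel_dispatcher-like call, and a Python side effect like this inside a hook is exactly the kind of thing that can graph-break under torch.compile.

import torch
from typing import Callable

def _attach_disable_on_backward(attn_out: torch.Tensor, disable_fn: Callable[[], None]) -> torch.Tensor:
    # Register a gradient hook so that disable_fn() only runs once the
    # backward pass reaches the attention output.
    def _hook(grad: torch.Tensor) -> torch.Tensor:
        disable_fn()
        return grad

    if attn_out.requires_grad:
        attn_out.register_hook(_hook)
    return attn_out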

Contributor Author


Discussed offline; keep it for now. We need to think about a better way to integrate DTensor with SDPA.

fegin added 4 commits October 6, 2025 22:09
fegin added 2 commits October 9, 2025 13:30
@fegin fegin added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 9, 2025
fegin added 4 commits October 9, 2025 22:49
@fegin fegin requested a review from tianyu-l October 12, 2025 05:31
@pytorch pytorch deleted a comment from pytorch-bot bot Oct 12, 2025
@fegin fegin dismissed tianyu-l’s stale review October 12, 2025 05:32

Discussed offline; keep it for now. We need to think about a better way to integrate DTensor with SDPA.

@pytorch pytorch deleted a comment from pytorch-bot bot Oct 12, 2025
@fegin
Contributor Author

fegin commented Oct 12, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information, see the pytorch-bot wiki.

@fegin
Contributor Author

fegin commented Oct 13, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 13, 2025
The custom op will fetch the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter.  While the logic is the same as all_gather_tensor_autograd, the custom op avoids the Autograd warning that wait_tensor() is registered to autograd.

For the next step, we should explore how to interpolate the required communication based on the information from BlockMask.

Pull Request resolved: #163185
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500
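
For illustration, a minimal sketch of the forward all-gather / backward reduce-scatter pattern described in the commit above, written against plain torch.distributed collectives. It assumes the local K/V shard is split along dim 0 across the process group, and it is not the actual custom op from #163185 (which additionally avoids the wait_tensor() autograd warning).

import torch
import torch.distributed as dist

class _AllGatherKV(torch.autograd.Function):
    # All-gather the local K/V shard in forward; reduce-scatter (sum) the
    # gradient back down to this rank's shard in backward.

    @staticmethod
    def forward(ctx, local_kv: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.group = group
        world_size = dist.get_world_size(group)
        gathered = local_kv.new_empty(world_size * local_kv.shape[0], *local_kv.shape[1:])
        dist.all_gather_into_tensor(gathered, local_kv.contiguous(), group=group)
        return gathered

    @staticmethod
    def backward(ctx, grad_gathered: torch.Tensor):
        world_size = dist.get_world_size(ctx.group)
        local_len = grad_gathered.shape[0] // world_size
        grad_local = grad_gathered.new_empty(local_len, *grad_gathered.shape[1:])
        dist.reduce_scatter_tensor(grad_local, grad_gathered.contiguous(), group=ctx.group)
        return grad_local, None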
pytorchmergebot pushed a commit that referenced this pull request Oct 14, 2025
…5039)

No logic change, just polish the docstrings, comments and remove unused variables

Pull Request resolved: #165039
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500, #163185
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
…orch#165039)

No logic change, just polish the docstrings, comments and remove unused variables

Pull Request resolved: pytorch#165039
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542, pytorch#164500, pytorch#163185
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…ytorch#164500)

`context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:

1. Wrap the attention op into an `nn.Module`
2. Lift any buffers that are not sequence agnostic to input

We can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`.

Pull Request resolved: pytorch#164500
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…h#163185)

The custom op will fetch the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter.  While the logic is the same as all_gather_tensor_autograd, the custom op avoids the Autograd warning that wait_tensor() is registered to autograd.

For the next step, we should explore how to interpolate the required communication based on the information from BlockMask.

Pull Request resolved: pytorch#163185
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542, pytorch#164500
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…orch#165039)

No logic change, just polish the docstrings, comments and remove unused variables

Pull Request resolved: pytorch#165039
Approved by: https://github.com/XilunWu
ghstack dependencies: pytorch#162542, pytorch#164500, pytorch#163185
@github-actions github-actions bot deleted the gh/fegin/326/head branch November 13, 2025 02:17

Labels

ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: context parallel (PyTorch Context Parallel)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: context parallel
