Skip to content

Test Copy Engine All-Gather#170265

Closed
kwen2501 wants to merge 2 commits intogh/kwen2501/293/basefrom
gh/kwen2501/293/head
Closed

Test Copy Engine All-Gather#170265
kwen2501 wants to merge 2 commits intogh/kwen2501/293/basefrom
gh/kwen2501/293/head

Conversation

@kwen2501
Copy link
Collaborator

@kwen2501 kwen2501 commented Dec 12, 2025

Stack from ghstack (oldest at bottom):

NCCL 2.28 added Copy Engine (CE) support.

Condition:

  • Tensors be symmetrically registered (e.g. coming from symm_mem.empty)
  • NCCL_CTA_POLICY_ZERO be passed to ncclConfig or env var NCCL_CTA_POLICY=2

Confirmed use of CE via profile:
Screenshot 2025-12-11 at 4 47 50 PM

(First kernel is from regular all-gather, second kernel is from all-gather on tensors that have been window registered)

Caveat:
As of 2.28.9, CE collectives cannot be run on default stream, so we are testing it with async_op=True or with a side stream.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/170265

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit be0a4a1 with merge base eed7d91 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Dec 12, 2025
kwen2501 added a commit that referenced this pull request Dec 12, 2025
ghstack-source-id: 8e166b7
Pull-Request: #170265
@kwen2501 kwen2501 added release notes: distributed (symm_mem) release note label for symmetric memory module: symm_mem Issues and PRs of Symmetric Memory and removed topic: not user facing topic category labels Dec 12, 2025
@kwen2501
Copy link
Collaborator Author

cc @weifengpy for potential use in FSDP for reducing compute-comm contention.

@kwen2501 kwen2501 requested review from fduwjj and ngimel December 12, 2025 01:11
[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Dec 12, 2025
kwen2501 added a commit that referenced this pull request Dec 12, 2025
ghstack-source-id: 3b5a426
Pull-Request: #170265
@kwen2501
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 12, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@weifengpy
Copy link
Contributor

@dzmitry-huba

vishalgoyal316 pushed a commit to vishalgoyal316/pytorch that referenced this pull request Dec 17, 2025
NCCL 2.28 added Copy Engine (CE) support.

Condition:
- Tensors be symmetrically registered (e.g. coming from `symm_mem.empty`)
- `NCCL_CTA_POLICY_ZERO` be passed to `ncclConfig` or env var `NCCL_CTA_POLICY=2`

Confirmed use of CE via profile:
<img width="988" height="132" alt="Screenshot 2025-12-11 at 4 47 50 PM" src="https://github.com/user-attachments/assets/2077d88b-34d9-4155-b323-646cab904e68" />

(First kernel is from regular all-gather, second kernel is from all-gather on tensors that have been window registered)

Caveat:
As of 2.28.9, CE collectives cannot be run on default stream, so we are testing it with `async_op=True` or with a side stream.
Pull Request resolved: pytorch#170265
Approved by: https://github.com/fduwjj
@Microve
Copy link
Contributor

Microve commented Dec 22, 2025

Wonder whether Copy Engine All-Gather works with torch.compile?

@kwen2501
Copy link
Collaborator Author

kwen2501 commented Jan 6, 2026

@Microve There are two scenarios:

(1) If the eager-mode program has been rewritten to enable CE, i.e. the user has been using symmetric memory:
Then yes, it would work with torch.compile. (The algorithm selection is internal to NCCL thus not visible by torch.compile)

(2) If the eager-mode program is written without symmetric memory:
There is an opportunity for torch.compile to convert the program to use SymmMem thus gaining an optimization.

cc @eellison @eee4017

krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
NCCL 2.28 added Copy Engine (CE) support.

Condition:
- Tensors be symmetrically registered (e.g. coming from `symm_mem.empty`)
- `NCCL_CTA_POLICY_ZERO` be passed to `ncclConfig` or env var `NCCL_CTA_POLICY=2`

Confirmed use of CE via profile:
<img width="988" height="132" alt="Screenshot 2025-12-11 at 4 47 50 PM" src="https://github.com/user-attachments/assets/2077d88b-34d9-4155-b323-646cab904e68" />

(First kernel is from regular all-gather, second kernel is from all-gather on tensors that have been window registered)

Caveat:
As of 2.28.9, CE collectives cannot be run on default stream, so we are testing it with `async_op=True` or with a side stream.
Pull Request resolved: pytorch#170265
Approved by: https://github.com/fduwjj
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: symm_mem Issues and PRs of Symmetric Memory open source release notes: distributed (symm_mem) release note label for symmetric memory topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants