multimem reduce #164517

Closed
kwen2501 wants to merge 2 commits into gh/kwen2501/272/base from gh/kwen2501/272/head

Conversation

kwen2501 (Collaborator) commented Oct 2, 2025:

Stack from ghstack (oldest at bottom):

Modified the `multimem_one_shot_all_reduce_out` function to accept a `root` argument, making it a `multimem_reduce` op.

The original `multimem_one_shot_all_reduce` op becomes a caller of `multimem_reduce`, with each rank providing its own rank id as the root.
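
For illustration, here is a minimal host-side sketch of that relationship, not the actual PR diff: the all-reduce signature and the `get_rank_in_group` helper are assumptions, while the `multimem_one_shot_reduce_out` signature follows the snippet quoted later in this thread.

at::Tensor multimem_one_shot_all_reduce_out(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name,
    at::Tensor out) {
  // Hypothetical helper: look up this rank's index within `group_name`.
  int64_t my_rank = get_rank_in_group(group_name);
  // Every rank passes its own rank as the root, so each rank reduces into
  // its own output buffer and the one-shot reduce behaves as an all-reduce.
  return multimem_one_shot_reduce_out(
      input, std::move(reduce_op), /*root=*/my_rank, std::move(group_name), out);
}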

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

[ghstack-poisoned]
pytorch-bot bot commented Oct 2, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164517

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 2f475da with merge base a707042:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Oct 2, 2025
ghstack-source-id: d04f156
Pull-Request: #164517
pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels on Oct 2, 2025
kwen2501 requested review from fduwjj, fegin and ngimel on Oct 2, 2025 at 22:39
kwen2501 added the release notes: distributed (symm_mem) label on Oct 2, 2025
kwen2501 (Collaborator, Author) commented Oct 3, 2025:

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label on Oct 3, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

shunting314 (Contributor) commented:

Do we drop support for CUDA versions before 12.3?

Otherwise, something like the following:

+at::Tensor multimem_one_shot_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    int64_t root,
+    std::string group_name,
+    at::Tensor out) {
+  TORCH_CHECK(false, "multimem_one_shot_reduce_out: requires CUDA 12.3+.");
+  return input;
+}

needs to be added after `#elif defined(CUDART_VERSION) && CUDART_VERSION < 12030` for the code to build.
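
For context, the guard structure being described would look roughly like the following. This is a hedged sketch: only the #elif condition is quoted from the comment above, while the first branch's condition and the surrounding file layout are assumptions.

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
// Real multicast-based implementation of multimem_one_shot_reduce_out
// (and the other multimem kernels) lives in this branch.
#elif defined(CUDART_VERSION) && CUDART_VERSION < 12030
// Fallback stub: keeps the symbol defined on older CUDA toolkits so the
// op registration still compiles, but fails loudly if actually called.
at::Tensor multimem_one_shot_reduce_out(
    const at::Tensor& input,
    std::string reduce_op,
    int64_t root,
    std::string group_name,
    at::Tensor out) {
  TORCH_CHECK(false, "multimem_one_shot_reduce_out: requires CUDA 12.3+.");
  return input;
}
#endif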

facebook-github-bot (Contributor) commented:

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Oct 7, 2025
This reverts commit d1cbb74.

Reverted #164517 on behalf of https://github.com/facebook-github-bot due to: Diff reverted internally (see the comment on #164517)
pytorchmergebot (Collaborator) commented:

@kwen2501 your PR has been successfully reverted.

pytorchmergebot added the Reverted and ci-no-td labels on Oct 7, 2025
kwen2501 (Collaborator, Author) commented Oct 7, 2025:

@shunting314 The code you cited is merely host-side code, without fancy CUDA APIs. Which part is not buildable? Can you share your error log?

shunting314 (Contributor) commented:

I don't have the log available, but it says `multimem_one_shot_reduce_out` is not defined, since the definition is in a conditional block guarded by the CUDA 12.3 check.

kwen2501 (Collaborator, Author) commented Oct 8, 2025:

I see, thank you @shunting314

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Oct 8, 2025
ghstack-source-id: e768757
Pull-Request: #164517
kwen2501 (Collaborator, Author) commented Oct 8, 2025:

Added the temporary workaround as suggested.
Longer term, we should remove all those false impls.

kwen2501 (Collaborator, Author) commented Oct 8, 2025:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Modified `multimem_one_shot_all_reduce_out` function to accept a `root` argument, making it a `multimem_reduce` op.

The original `multimem_one_shot_all_reduce` op becomes a caller of the `multimem_reduce`, with each rank providing its own rank id as root.

Pull Request resolved: pytorch#164517
Approved by: https://github.com/ngimel
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
This reverts commit d1cbb74.

Reverted pytorch#164517 on behalf of https://github.com/facebook-github-bot due to: Diff reverted internally (see the comment on pytorch#164517)
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Modified `multimem_one_shot_all_reduce_out` function to accept a `root` argument, making it a `multimem_reduce` op.

The original `multimem_one_shot_all_reduce` op becomes a caller of the `multimem_reduce`, with each rank providing its own rank id as root.

Pull Request resolved: pytorch#164517
Approved by: https://github.com/ngimel
@github-actions github-actions bot deleted the gh/kwen2501/272/head branch November 8, 2025 02:10

Labels

ci-no-td, ciflow/h100-symm-mem, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (c10d), release notes: distributed (symm_mem), Reverted
