[symm_mem] Add wait_signal and put_signal one-sided APIs #159837
fduwjj wants to merge 6 commits into gh/fduwjj/176/base from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159837

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 89b4650 with merge base fde929c: FLAKY, the following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ne side API" cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta
nice :)
```cpp
}

void nvshmem_put_with_signal(at::Tensor& tensor, int64_t peer) {
  auto hdl = c10d::symmetric_memory::rendezvous(tensor, "0");
```
What's "0" in this case? Also, are we expected to call rendezvous among every rank in the group, or just the ranks that are doing the put/get?
In this case probably all ranks?
"0" means the global group. It is a temporary setting that can go wrong if the group is not actually global.
We need the handle to remember the group on which it has rendezvoused.
```cpp
  c10::cuda::CUDAGuard guard(tensor.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  nvshmemx_putmem_signal_on_stream(
      buffer_ptr,
      tensor.data_ptr(),
      buffer_size,
      static_cast<uint64_t*>(signal_ptr),
      NVSHMEM_SIGNAL_SET,
      1,
      peer,
      stream);
```
Here the dst can be tensor.data_ptr() too. (A reminder for myself to refactor the whole file after we land MemPool support.)
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…#159837) Pull Request resolved: #159837 Approved by: https://github.com/kwen2501
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci