
[SymmMem] Fix put_signal + wait_until hang #163194

Closed
kwen2501 wants to merge 1 commit into gh/kwen2501/251/base from gh/kwen2501/251/head

Conversation

kwen2501 (Collaborator) commented Sep 17, 2025

Stack from ghstack (oldest at bottom):

The test used the wrong pointers to refer to remote addresses:

```
dst_ptr = out_hdl.buffer_ptrs[peer]
src_ptr = inp_hdl.buffer_ptrs[rank]
sig_ptr = out_hdl.signal_pad_ptrs[peer]
```

All three indices should be `rank` instead of `peer`, because NVSHMEM APIs take local addresses as input and perform the translation to the remote address space internally. With the wrong signal address, the peer waits on a signal that is never set, hence the hang.
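For clarity, the fixed snippet (as described above) indexes all three with `rank`:

```
dst_ptr = out_hdl.buffer_ptrs[rank]       # local address of the dst buffer; NVSHMEM translates it for `peer`
src_ptr = inp_hdl.buffer_ptrs[rank]       # local address of the src buffer
sig_ptr = out_hdl.signal_pad_ptrs[rank]   # local signal pad address that the remote put_signal targets
```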

Also adjusted the signature of `nvshmem.putmem_signal_block` to accept tensors instead of raw pointers.
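As a rough illustration of the overall pattern, here is a minimal sketch of the put_signal / wait_until pairing the test exercises. The call shapes and the names `SIG_VAL`, `NVSHMEM_CMP_EQ`, and `nvshmem.signal_wait_until` are illustrative assumptions, not the exact PyTorch API:

```
# Sketch only: argument names and order are assumptions for illustration.
if rank == 0:
    # Producer: push `inp` into the peer's `out` buffer, then set the
    # peer's signal pad. Local tensors are passed; NVSHMEM translates
    # the symmetric addresses to the peer's address space internally.
    nvshmem.putmem_signal_block(out, inp, sig, SIG_VAL, peer)
else:
    # Consumer: spin on the *local* signal pad until the producer's
    # signal arrives. With the wrong (peer-indexed) address, this wait
    # never observes the signal, which is exactly the reported hang.
    nvshmem.signal_wait_until(sig, NVSHMEM_CMP_EQ, SIG_VAL)
```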

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

[ghstack-poisoned]
pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163194

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d2f3060 with merge base 4840a1a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the ciflow/h100-symm-mem and oncall: distributed labels Sep 17, 2025
kwen2501 added a commit that referenced this pull request Sep 17, 2025
kwen2501 added the release notes: distributed (symm_mem) label Sep 17, 2025
kwen2501 requested review from fegin and ngimel September 17, 2025 21:20
kwen2501 (Collaborator, Author) commented:

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label Sep 18, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator) commented:

This PR (#163194) was merged in 80f8be9 but it is still open, likely due to a GitHub bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.

jeffkbkim pushed a commit to jeffkbkim/pytorch that referenced this pull request Sep 18, 2025
Pull Request resolved: pytorch#163194
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#163025, pytorch#163152
pytorchmergebot pushed a commit that referenced this pull request Sep 21, 2025
…3423)

### Issue
The previous `enable_triton` UI required the user-defined Triton kernel to have "nvshmem" in its name. If users did not, the kernel would miss the NVSHMEM initialization and silently hit a CUDA illegal memory access (IMA).

The `@requires_nvshmem` decorator eliminates this naming requirement (and the `enable_triton` call).

### Usage:
```
@requires_nvshmem
@triton.jit
def foo(...):
    ...

foo[(1, 1)](...)
```
It also removes the need to pass `extern_lib` to `foo` (the decorator now handles it).

Pull Request resolved: #163423
Approved by: https://github.com/ngimel
ghstack dependencies: #163025, #163152, #163194
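For contrast with the decorator usage above, here is a rough before/after sketch. The pre-decorator launch shape, including the `nvshmem_lib` variable and the `extern_libs` argument, is an assumption for illustration, not verified API:

```
# Before (assumed shape): the kernel name must contain "nvshmem", and the
# library handle returned by enable_triton() is passed at each launch.
nvshmem_lib = enable_triton()

@triton.jit
def my_nvshmem_kernel(...):
    ...

my_nvshmem_kernel[(1, 1)](..., extern_libs=nvshmem_lib)

# After: @requires_nvshmem handles init and linkage; any kernel name works.
foo[(1, 1)](...)
```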
kwen2501 (Collaborator, Author) commented:

@pytorchbot cherry-pick --onto release/2.9 --fixes #162934 -c critical

pytorchbot pushed a commit that referenced this pull request Sep 21, 2025

(cherry picked from commit 80f8be9)
pytorchbot (Collaborator) commented:

Cherry picking #163194

The cherry-pick PR is at #163458 and is linked with issue #162934.


mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
Camyll pushed a commit that referenced this pull request Sep 23, 2025
[SymmMem] Fix put_signal + wait_until hang (#163194)

(cherry picked from commit 80f8be9)

Co-authored-by: Ke Wen <kw2501@meta.com>
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
github-actions bot deleted the gh/kwen2501/251/head branch October 22, 2025 02:15

Labels

ciflow/h100-symm-mem · ciflow/trunk · Merged · oncall: distributed · release notes: distributed (symm_mem)


4 participants