
[SymmMem] Add MemPool support to CUDA backend #169740

Closed

kwen2501 wants to merge 5 commits into gh/kwen2501/290/base from gh/kwen2501/290/head

Conversation

kwen2501 (Collaborator) commented Dec 6, 2025

Stack from ghstack (oldest at bottom):

[1/N] Extended the rendezvous matching condition from an exact address match to the case where a tensor falls anywhere within an allocation's range (see the sketch after this list).

[2/N] Shifted all the heavy work (everything involving cudaMalloc) from cudaSymmetricMemory to cudaPeerAllocInfo. The former now corresponds to a tensor, while the latter corresponds to an allocation. Tensors on the same allocation share the same cudaPeerAllocInfo.

[3/N] Added tests.
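
A hedged Python sketch of what the range-based matching enables. This is not taken from this PR's tests; it assumes a torchrun launch with one rank per CUDA device:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# One symmetric-memory allocation per rank.
buf = symm_mem.empty(1024, dtype=torch.float32, device=f"cuda:{rank}")

# A view whose data_ptr() lies inside the allocation but is not its base.
sub = buf[256:512]

# Rendezvous used to require an exact base-address match; with this stack, the
# pointer only has to fall within a known allocation range, and tensors on the
# same allocation share one peer-allocation record.
hdl_full = symm_mem.rendezvous(buf, dist.group.WORLD)
hdl_sub = symm_mem.rendezvous(sub, dist.group.WORLD)
```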

pytorch-bot bot commented Dec 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169740

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 3 Pending

As of commit 1082ad3 with merge base 04ae0e1:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Dec 6, 2025
[1/N][SymmMem] Reuse handle when ptr falls in allocation range

[2/N] Reuse peer allocation info


ghstack-source-id: a3b0e21
Pull-Request: #169740
kwen2501 added the release notes: distributed (symm_mem) and module: symm_mem labels Dec 6, 2025
eqy requested a review from galv December 6, 2025 01:45
kwen2501 added a commit that referenced this pull request Dec 6, 2025
[1/N][SymmMem] Reuse handle when ptr falls in allocation range

[2/N] Reuse peer allocation info


ghstack-source-id: 0637eec
Pull-Request: #169740
kwen2501 added a commit that referenced this pull request Dec 8, 2025
[1/N][SymmMem] Reuse handle when ptr falls in allocation range

[2/N] Reuse peer allocation info

ghstack-source-id: bdc9039
Pull-Request: #169740
kwen2501 added a commit that referenced this pull request Dec 8, 2025
[1/N][SymmMem] Reuse handle when ptr falls in allocation range

[2/N] Reuse peer allocation info

ghstack-source-id: b3c60b7
Pull-Request: #169740
kwen2501 requested review from fduwjj, fegin, and ngimel December 8, 2025 23:11
kwen2501 added a commit that referenced this pull request Dec 9, 2025
[1/N][SymmMem] Reuse handle when ptr falls in allocation range

[2/N] Reuse peer allocation info

ghstack-source-id: 2ecdb7f
Pull-Request: #169740
A collaborator opened a review thread on this snippet from the diff:

}

/* Search for a block that covers the given ptr, and write back the offset to
 * the base ptr; error out if not found */

Comment seems wrong: the function returns nullptr if not found, it does not error out.
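
A hedged Python illustration of the lookup behavior discussed in the thread (the real code is C++): find the block whose range covers a pointer and report the offset, returning None, rather than raising, when no block matches:

```python
def find_block(blocks, ptr):
    """Hypothetical stand-in; `blocks` maps base address -> allocation size in bytes."""
    for base, size in blocks.items():
        if base <= ptr < base + size:
            return base, ptr - base  # covering block and offset into it
    return None  # not found; mirrors the nullptr return noted in the review
```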

kwen2501 (Author) commented:

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label Dec 10, 2025
pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot commented:

Starting merge as part of PR stack under #170008

pytorchmergebot commented:

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

kwen2501 (Author) commented:

@pytorchbot merge -f "ROCm failure comes from runner: Available diskspace is less than 30 percent. Not enough diskspace"

pytorchmergebot commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Dec 10, 2025
We'd like to default some flags for SymmetricMemory pools, e.g.
- use_on_oom=False
- no_split=True

to improve UX and to strengthen safety.

We thus provide a wrapper with the above flags preset:
```
pool = torch.distributed._symmetric_memory.get_mem_pool(device)
```

Since these flags are internal to the wrapper, we also retain the flexibility to vary them in the future. (A hedged usage sketch follows this commit entry.)
Pull Request resolved: #170008
Approved by: https://github.com/ngimel
ghstack dependencies: #169739, #169740
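
A hedged usage sketch for the get_mem_pool wrapper described above. It assumes the #170008 API exactly as quoted and a torchrun launch with one rank per CUDA device; torch.cuda.use_mem_pool routes allocations made inside the context into the given pool:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

# Pool preset with SymmetricMemory-friendly flags (use_on_oom=False, no_split=True).
pool = symm_mem.get_mem_pool(device)

# Ordinary allocations made inside this context come from the symmetric-memory
# pool, so they can be rendezvous'd afterwards even though they were not
# created via symm_mem.empty.
with torch.cuda.use_mem_pool(pool):
    t = torch.zeros(4096, device=device)

hdl = symm_mem.rendezvous(t, dist.group.WORLD)
```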
skpark-rh pushed a commit to skpark-rh/pytorch that referenced this pull request Dec 10, 2025
[1/N] Extended rendezvous matching condition from exact address match to case where tensor falls in allocation range.

[2/N] Shifted all the heavy work (everything involving cudaMalloc) from `cudaSymmetricMemory` to `cudaPeerAllocInfo`. The former now corresponds to a tensor, while the latter corresponds to an allocation. Tensors on the same allocation share the same `cudaPeerAllocInfo`.

[3/N] Added tests.
Pull Request resolved: pytorch#169740
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#169739
skpark-rh pushed a commit to skpark-rh/pytorch that referenced this pull request Dec 10, 2025 (same commit message as pytorch#170008 above)
daisyden pushed a commit to daisyden/pytorch that referenced this pull request Dec 11, 2025 (same commit message as pytorch#170008 above)
github-actions bot deleted the gh/kwen2501/290/head branch January 10, 2026 02:21

Labels

ciflow/h100-symm-mem, ciflow/trunk, Merged, module: symm_mem, open source, release notes: distributed (c10d), release notes: distributed (symm_mem)


5 participants