Skip to content

[MPS] Support large tensors in torch.cat#164416

Closed
kurtamohler wants to merge 2 commits intogh/kurtamohler/55/basefrom
gh/kurtamohler/55/head
Closed

[MPS] Support large tensors in torch.cat#164416
kurtamohler wants to merge 2 commits intogh/kurtamohler/55/basefrom
gh/kurtamohler/55/head

Conversation

@kurtamohler
Copy link
Collaborator

@kurtamohler kurtamohler commented Oct 1, 2025

Stack from ghstack (oldest at bottom):

Fixes #164415

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164416

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 20ff360 with merge base 24d69c5 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kurtamohler added a commit that referenced this pull request Oct 1, 2025
Fixes #164415


ghstack-source-id: c35e34e
Pull-Request: #164416
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Oct 1, 2025

has_large_tensor |= isTooLargeForMPSGraph(out);

if (has_large_tensor) {
Copy link
Collaborator Author

@kurtamohler kurtamohler Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to check whether the alternate implementation works correctly for smaller sizes as well, so I tried temporarily changing this condition to always be true. I ran python -m pytest test/test_mps.py -k output_match_cat and it passed.

But it might be a good idea to cover this in CI too. To do that, we could add a non-public python api (either a global flag or a non-public function) that forces calling the alternate impl even if the tensors are small. Then we could add a test in test_mps.py that runs the opinfo cases for cat using the alternate impl.

But idk, maybe that's overkill. Let me know if it seems like something we'd want to do

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason not to use it unconditionally? I suspect MPSGraph construction overhead for small tensors is probably significant, and perf for medium sized tensors should be the same.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll make a follow-up PR that replaces the MPSGraph impl

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably looks fine to me, though I think it would be good to have an implementation that is more perf-aware and could completely replace MPSGraph.
I guess to achieve that one needs to have fast-path kernel variants for storage-dense tensors and may be just one flavor that supports type-casts (by doing if condition rather than have all possible permutations of the kernels, see example of the copy kernel)


has_large_tensor |= isTooLargeForMPSGraph(out);

if (has_large_tensor) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason not to use it unconditionally? I suspect MPSGraph construction overhead for small tensors is probably significant, and perf for medium sized tensors should be the same.

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Oct 9, 2025
Fixes #164415

ghstack-source-id: ec37cce
Pull-Request: #164416
@kurtamohler
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@kurtamohler
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
@github-actions github-actions bot deleted the gh/kurtamohler/55/head branch November 13, 2025 02:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants