[MPS] Support large tensors in torch.cat#164416
kurtamohler wants to merge 2 commits into gh/kurtamohler/55/base
Dr. CI: No failures as of commit 20ff360 with merge base 24d69c5. Artifacts and rendered test results: hud.pytorch.org/pr/164416
has_large_tensor |= isTooLargeForMPSGraph(out);

if (has_large_tensor) {
I wanted to check whether the alternate implementation also works correctly for smaller sizes, so I temporarily changed this condition to always be true and ran python -m pytest test/test_mps.py -k output_match_cat. It passed.
It might be a good idea to cover this in CI too. To do that, we could add a non-public Python API (either a global flag or a non-public function) that forces the alternate impl even when the tensors are small. Then we could add a test in test_mps.py that runs the OpInfo cases for cat using the alternate impl.
That might be overkill, though. Let me know if it seems like something we'd want to do.
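For illustration, the flag-based idea could look roughly like this. It is a pure-Python sketch of the pattern only, not PyTorch code: `_FORCE_ALTERNATE_CAT`, `force_alternate_cat`, `cat_graph`, `cat_alternate`, and the `LARGE_NUMEL` threshold are all hypothetical names and values, and plain lists stand in for tensors.

```python
from contextlib import contextmanager

# Hypothetical threshold; the real limit checked by isTooLargeForMPSGraph
# is not stated in this thread.
LARGE_NUMEL = 1 << 31

# Hypothetical non-public global flag that forces the alternate path.
_FORCE_ALTERNATE_CAT = False


@contextmanager
def force_alternate_cat():
    """Temporarily force the alternate implementation, even for small inputs."""
    global _FORCE_ALTERNATE_CAT
    previous = _FORCE_ALTERNATE_CAT
    _FORCE_ALTERNATE_CAT = True
    try:
        yield
    finally:
        _FORCE_ALTERNATE_CAT = previous


def cat_graph(inputs):
    # Stand-in for the default (MPSGraph-style) path.
    return [x for seq in inputs for x in seq]


def cat_alternate(inputs):
    # Stand-in for the alternate path: preallocate the output, then copy
    # each input into its offset.
    out = [None] * sum(len(seq) for seq in inputs)
    offset = 0
    for seq in inputs:
        out[offset:offset + len(seq)] = seq
        offset += len(seq)
    return out


def cat(inputs):
    total = sum(len(seq) for seq in inputs)
    has_large = total >= LARGE_NUMEL or any(len(seq) >= LARGE_NUMEL for seq in inputs)
    if has_large or _FORCE_ALTERNATE_CAT:
        return cat_alternate(inputs)
    return cat_graph(inputs)
```

A test could then run the same small cases with and without `force_alternate_cat()` and compare the results, which is the coverage the comment above is asking about.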
Is there a reason not to use it unconditionally? I suspect MPSGraph construction overhead for small tensors is probably significant, and perf for medium sized tensors should be the same.
Sounds good, I'll make a follow-up PR that replaces the MPSGraph impl
malfet left a comment
This probably looks fine to me, though I think it would be good to have an implementation that is more perf-aware and could completely replace MPSGraph.
I guess to achieve that, one needs fast-path kernel variants for storage-dense tensors and maybe just one flavor that supports type casts (by using an if condition rather than having all possible permutations of the kernels; see the copy kernel for an example).
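One way to read this suggestion, sketched in pure Python over 1-D buffers: a bulk-copy fast path for dense, same-type inputs, plus a single generic variant that handles any stride and decides on the cast with a runtime condition. `StridedView`, `copy_into`, and `cat_1d` are hypothetical names for illustration; this is not the Metal kernel code.

```python
from dataclasses import dataclass


@dataclass
class StridedView:
    """Hypothetical 1-D view: `numel` elements of `storage`, starting at
    `offset` and spaced `stride` apart."""
    storage: list
    numel: int
    offset: int = 0
    stride: int = 1

    def is_dense(self) -> bool:
        return self.stride == 1


def copy_into(out, out_offset, src, cast=None):
    if src.is_dense() and cast is None:
        # Fast path: dense source and no type cast, so one bulk copy suffices.
        out[out_offset:out_offset + src.numel] = src.storage[src.offset:src.offset + src.numel]
        return
    # Generic path: a single variant handles any stride, and the cast is
    # chosen by a runtime condition instead of one kernel per type pair.
    for i in range(src.numel):
        value = src.storage[src.offset + i * src.stride]
        if cast is not None:
            value = cast(value)
        out[out_offset + i] = value


def cat_1d(views, cast=None):
    out = [None] * sum(v.numel for v in views)
    offset = 0
    for v in views:
        copy_into(out, offset, v, cast)
        offset += v.numel
    return out
```

For example, `cat_1d([StridedView([1, 2, 3], numel=3), StridedView([10, 0, 20, 0, 30], numel=3, stride=2)])` takes the fast path for the first view and the generic path for the second.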
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
Merge failed. Reason: 1 mandatory check(s) failed.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).
Fixes pytorch#164415
Pull Request resolved: pytorch#164416
Approved by: https://github.com/malfet
Stack from ghstack (oldest at bottom):
torch.cat #164416
Fixes #164415