Fix clamp broadcasting on MPS (Fixes #160734)#165058
Fix clamp broadcasting on MPS (Fixes #160734)#165058roei-shlezinger wants to merge 2 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165058
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 8054314 with merge base 6ea7791 ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot rerun lintrunner-noclang-partial |
|
❌ 🤖 pytorchbot command failed: Try |
|
@pytorchbot --help |
PyTorchBot HelpMergeRevertRebaseLabelDr CIcherry-pick |
|
Hi! I’m seeing two CI failures that don’t seem related to my changes: the DTensor dropout test on CPU, and the lintrunner job stopping because conda isn’t available in the linter image. lintrunner -a passes locally for me. @kulinseth could you advise on the next steps? Thanks! |
dc5b816 to
8054314
Compare
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR fixes a bug where `torch.clamp` on MPS fails when min/max tensors have more dimensions than the input tensor. CPU already supports this broadcasting, but MPS raised a RuntimeError. Example of failing case before the fix: ```python x = torch.randn(2, 3, device="mps") min_t = torch.randn(1, 2, 3, device="mps") max_t = torch.randn(1, 2, 3, device="mps") torch.clamp(x, min=min_t, max=max_t) # RuntimeError ``` After this fix, MPS matches CPU behavior. Fixes pytorch#160734 Pull Request resolved: pytorch#165058 Approved by: https://github.com/malfet
This PR fixes a bug where
torch.clampon MPS fails when min/max tensors have more dimensions than the input tensor.CPU already supports this broadcasting, but MPS raised a RuntimeError.
Example of failing case before the fix:
After this fix, MPS matches CPU behavior.
Fixes #160734