-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[SR] Eliminate extra permute ops before aten::sum
#74481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` Differential Revision: [D34992319](https://our.internmc.facebook.com/intern/diff/D34992319/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34992319/)! [ghstack-poisoned]
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit 9b55df5 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` Differential Revision: [D34992319](https://our.internmc.facebook.com/intern/diff/D34992319/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34992319/)! [ghstack-poisoned]
Pull Request resolved: #74481 This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` ghstack-source-id: 151822137 Differential Revision: [D34992319](https://our.internmc.facebook.com/intern/diff/D34992319/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34992319/)!
This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` Differential Revision: [D34992319](https://our.internmc.facebook.com/intern/diff/D34992319/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34992319/)! [ghstack-poisoned]
Pull Request resolved: #74481 This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` ghstack-source-id: 152003888 Differential Revision: [D34992319](https://our.internmc.facebook.com/intern/diff/D34992319/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34992319/)!
Summary: Pull Request resolved: #74481 This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` ghstack-source-id: 152003888 Reviewed By: navahgar Differential Revision: D34992319 fbshipit-source-id: 0baf493708ee2180c899814a954d220d88ba1d4f
|
Hey @mikeiovine. |
Summary: Pull Request resolved: #74481 This diff fixes an interesting performance issue related to `permute_copy`. We see this pattern frequently: ``` y = torch.permute(x, (0, 2, 1)) z = torch.sum(y, dim=-1) ``` With copy variants off, we get a strided output from `permute`, and we hit this (faster) kernel in `sum`: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589 But with copy variants on, we get a contiguous output from `permute_copy`, which causes us to hit the slower reduction: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597 But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on: ``` z = torch.sum(x, dim=1) ``` ghstack-source-id: 152003888 Reviewed By: navahgar Differential Revision: D34992319 fbshipit-source-id: 0baf493708ee2180c899814a954d220d88ba1d4f (cherry picked from commit 797b6be)
Stack from ghstack (oldest at bottom):
aten::sum#74481This diff fixes an interesting performance issue related to
permute_copy.We see this pattern frequently:
With copy variants off, we get a strided output from
permute, and we hit this (faster) kernel insum: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L589But with copy variants on, we get a contiguous output from
permute_copy, which causes us to hit the slower reduction:https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/SumKernel.cpp#L597
But the permute is actually unnecessary, we can just statically turn the graph into this to ensure that the fast kernel is hit with copy variants on:
Differential Revision: D34992319
NOTE FOR REVIEWERS: This PR has internal Facebook specific changes or comments, please review them on Phabricator!