[PyTorch][Static Runtime] out variant for _s_where #73438
Conversation
`aten::where` has no out variant in PyTorch core, and writing one is pretty tricky because it's split into `where` and `_s_where` to make autograd easier ([ref](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L4773)). This diff implements the `where` out variant for static runtime as follows:

* Added an out variant for `_s_where` in PyTorch
* Added a static runtime implementation for `where_out` that duplicates a small amount of broadcasting logic. This function looks a lot like `where` in PyTorch, but it's slightly more efficient: static runtime can skip the device check, since we take it as given that all tensors are on CPU. We can also skip a round of dispatch by directly calling `at::cpu::_s_where_out` instead of `at::_s_where_out`. (A sketch of this function appears below.)

Differential Revision: [D34469785](https://our.internmc.facebook.com/intern/diff/D34469785/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34469785/)!
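A minimal sketch of the static runtime `where_out` described above, not the PR's verbatim code: the function name, the included header, and the argument order of `at::cpu::_s_where_out` are assumptions for illustration.

```cpp
#include <ATen/ATen.h>
#include <ATen/CPUFunctions.h>  // assumed home of the at::cpu:: entry points

// Sketch only: broadcast the three inputs (the "small amount of duplicated
// broadcasting logic"), then call the CPU structured kernel directly,
// skipping both the device check and one round of dispatch.
at::Tensor& where_out_sketch(
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other,
    at::Tensor& out) {
  // at::native::where performs a similar expansion before dispatching.
  auto b = at::broadcast_tensors({condition, self, other});
  // All tensors are taken as given to be on CPU, so no device check here.
  // The argument order of the generated out function is an assumption.
  return at::cpu::_s_where_out(out, b[0], b[1], b[2]);
}
```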
The review thread below is on the new meta function:

```cpp
TORCH_META_FUNC(_s_where)
(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
  // ...
```
This is not very good, but I'm not sure there is necessarily a better way to do it. What is not good about it is that meta functions are not supposed to do any nontrivial compute, because they may get invoked as part of, e.g., just shape propagation. If condition happens to be a meta tensor this will be relatively cheap, but something that is supposed to work is, for example, calling `meta::_s_where(cpu_tensors...)`, and that will uselessly do some compute in this case.
What makes matters a little better is that `TensorIterator::build` (used everywhere here) DOES do some temporary allocations, and that's kind of hacked around by just checking whether any input argument is meta and avoiding the temporary allocations if so. So I'm willing to let this slide, but let's have a comment + issue about it.
cc @ysiraichi
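To make the cost concern concrete, a hedged illustration (editorial, not from the PR) of the shape-propagation use case meta kernels serve; the worry above is the same kernel being invoked with real CPU tensors, where the Byte-to-Bool `to()` copies actual data.

```cpp
#include <ATen/ATen.h>

int main() {
  // Meta tensors carry only shape/dtype metadata, so running where's meta
  // kernel on them is cheap (assuming _s_where has meta support, which is
  // what this PR adds).
  auto cond = at::empty({2, 3},
      at::TensorOptions().device(at::kMeta).dtype(at::kBool));
  auto a = at::empty({2, 3}, at::TensorOptions().device(at::kMeta));
  auto b = at::empty({2, 3}, at::TensorOptions().device(at::kMeta));

  auto out = at::where(cond, a, b);  // shape inference only
  TORCH_CHECK(out.is_meta() && out.sizes() == cond.sizes());

  // The concern: if `condition` were a real CPU Byte tensor routed through
  // the meta kernel for shape propagation, the to(ScalarType::Bool) branch
  // in TORCH_META_FUNC(_s_where) would copy actual data for nothing.
  return 0;
}
```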
Sure, I'll add a comment. It's also worth noting that this byte -> bool conversion is deprecated, so hopefully the conversion will not be needed for much longer.
ezyang left a comment:
I'm not qualified for the static runtime bit, but everything else seems fine. This adds meta support for `_s_where`; do we get test coverage for free?
I'm not sure. How are meta functions usually tested? Would be happy to add tests if they're missing.
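(Editorial aside: a minimal hand-rolled check of the kind such a test might perform, hypothetical and not PyTorch's actual meta-test infrastructure: run the op on Meta tensors and confirm the inferred shape and dtype agree with a real CPU run.)

```cpp
#include <ATen/ATen.h>

int main() {
  // Real CPU run, with broadcasting in play.
  auto cond_cpu = at::randn({4, 1}).gt(0);  // random bool mask
  auto a_cpu = at::randn({4, 5});
  auto b_cpu = at::randn({1, 5});
  auto out_cpu = at::where(cond_cpu, a_cpu, b_cpu);

  // Meta run: same shapes and dtypes, no data.
  auto opts = at::TensorOptions().device(at::kMeta);
  auto cond_meta = at::empty({4, 1}, opts.dtype(at::kBool));
  auto out_meta = at::where(cond_meta, at::empty({4, 5}, opts),
                            at::empty({1, 5}, opts));

  // The meta kernel should agree with the CPU kernel on shape and dtype.
  TORCH_CHECK(out_cpu.sizes() == out_meta.sizes());
  TORCH_CHECK(out_cpu.scalar_type() == out_meta.scalar_type());
  return 0;
}
```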
Summary:
Pull Request resolved: #73438

Add out variant for `where.self`; requires PyTorch core changes as no out variant existed previously.

ghstack-source-id: 151505601

Test Plan:
* Existing `where` tests in static runtime pass
* CI for core `where` tests

Reviewed By: hlu1

Differential Revision: D34469785

fbshipit-source-id: 8a4ebbf38b2364534fbf43812bfcfdf69ea174b3