
Conversation


@mikeiovine mikeiovine commented Feb 25, 2022

Stack from ghstack (oldest at bottom):


`aten::where` has no out variant in PyTorch core, and writing one is pretty tricky because it's split into `where` and `_s_where` to make autograd stuff easier ([ref](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L4773)).

This diff implements the `where` out variant for static runtime as follows:

* Added an out variant for `_s_where` in PyTorch
* Added a static runtime implementation for `where_out` that duplicates a small amount of broadcasting logic. This function looks much like `where` in PyTorch core, but it's slightly more efficient: static runtime can skip the device check, since we take it as given that all tensors are on CPU. We can also skip a round of dispatch by calling `at::cpu::_s_where_out` directly instead of `at::_s_where_out`.
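As a side note on what an out variant buys here: the pattern of writing into a caller-provided buffer instead of allocating a fresh result can be sketched in plain C++ (an illustrative toy, not the ATen implementation; the real one also handles broadcasting, dtypes, and output resizing):

```cpp
#include <cstddef>
#include <vector>

// Minimal sketch of the out-variant pattern: the caller owns the output
// buffer, so repeated calls (as in static runtime) can reuse memory instead
// of allocating a new result each time. Illustrative only.
std::vector<int>& where_out(const std::vector<bool>& cond,
                            const std::vector<int>& x,
                            const std::vector<int>& y,
                            std::vector<int>& out) {
  for (std::size_t i = 0; i < cond.size(); ++i) {
    out[i] = cond[i] ? x[i] : y[i];
  }
  return out;
}
```

Static runtime preallocates and reuses `out` across iterations, which is where the savings over the functional `where` come from.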

Differential Revision: [D34469785](https://our.internmc.facebook.com/intern/diff/D34469785/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34469785/)!

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Feb 25, 2022

CI Flow Status

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/0a5b6bbbd230be80fd13adffc50cdd3564d9e6b5/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

**Triggered Workflows**

| Workflow | Labels | Status |
| --- | --- | --- |
| linux-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| linux-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| linux-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| linux-binary-manywheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| linux-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk | ✅ triggered |
| linux-bionic-rocm4.5-py3.7 | ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk | ✅ triggered |
| linux-docs | ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-vulkan-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-onnx | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7-no-ops | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| macos-arm64-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| macos-arm64-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| macos-binary-conda | ciflow/binaries, ciflow/binaries_conda, ciflow/default | ✅ triggered |
| macos-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| macos-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| macos-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| windows-binary-libtorch-cxx11-abi | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| windows-binary-libtorch-pre-cxx11 | ciflow/binaries, ciflow/binaries_libtorch, ciflow/default | ✅ triggered |
| windows-binary-wheel | ciflow/binaries, ciflow/binaries_wheel, ciflow/default | ✅ triggered |

**Skipped Workflows**

| Workflow | Labels | Status |
| --- | --- | --- |
| caffe2-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| docker-builds | ciflow/all, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk | 🚫 skipped |
| linux-docs-push | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| linux-xenial-cuda11.3-py3.7-gcc7-no-ops | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| parallelnative-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| periodic-win-vs2019-cuda11.5-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build | ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| pytorch-xla-linux-bionic-py3.7-clang8 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla | 🚫 skipped |

@facebook-github-bot

facebook-github-bot commented Feb 25, 2022

💊 CI failures summary and remediations

As of commit d7fba92 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).


@facebook-github-bot added the `cla signed` and `oncall: jit` (Add this issue/PR to JIT oncall triage queue) labels Feb 25, 2022
mikeiovine pushed a commit that referenced this pull request Feb 25, 2022
ghstack-source-id: 149977789
Pull Request resolved: #73438
mikeiovine pushed a commit that referenced this pull request Feb 25, 2022
Pull Request resolved: #73438

ghstack-source-id: 150007304

```cpp
TORCH_META_FUNC(_s_where) (const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
```
This is not great, but I'm not sure there's a better way to do it. The problem is that meta functions are not supposed to do any nontrivial compute, because they may be invoked purely for shape propagation. If `condition` happens to be a meta tensor, this conversion is relatively cheap, but calling e.g. `meta::_s_where(cpu_tensors...)` is also supposed to work, and in that case it will uselessly do real compute.

What makes matters a little better is that `TensorIterator::build` (used everywhere here) DOES do some temporary allocations, and that's already hacked around by checking whether any input argument is meta and skipping the temporaries in that case. So I'm willing to let this slide, but let's have a comment and an issue about it.

cc @ysiraichi
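For intuition, the shape-only work a meta function is supposed to be limited to can be sketched in plain C++; this illustrative broadcast-shape computation (standard NumPy-style trailing-dimension rules, not the actual ATen code) touches only shapes, never element data:

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Compute the broadcast output shape of several input shapes without ever
// reading tensor data -- the kind of bookkeeping a meta function may do.
std::vector<long> broadcast_shape(const std::vector<std::vector<long>>& shapes) {
  std::size_t ndim = 0;
  for (const auto& s : shapes) ndim = std::max(ndim, s.size());
  std::vector<long> out(ndim, 1);  // missing leading dims broadcast as size 1
  for (std::size_t i = 0; i < ndim; ++i) {
    for (const auto& s : shapes) {
      // Align shapes at their trailing dimensions.
      long d = (i + s.size() >= ndim) ? s[i + s.size() - ndim] : 1;
      if (d != 1) {
        if (out[i] != 1 && out[i] != d) {
          throw std::invalid_argument("shapes are not broadcastable");
        }
        out[i] = d;
      }
    }
  }
  return out;
}
```

Here `broadcast_shape({{3, 1, 4}, {1, 5, 4}, {5, 1}})` yields `{3, 5, 4}`; anything beyond this kind of bookkeeping, such as the byte-to-bool conversion in the snippet above, is real compute.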


Sure, I'll add a comment. It's also worth noting that this byte -> bool conversion is deprecated, so hopefully the conversion won't be needed for much longer.


@ezyang ezyang left a comment


I'm not qualified to review the static runtime bit, but everything else seems fine. This adds meta support for `_s_where`; do we get test coverage for free?

@mikeiovine

> not qualified for the static runtime bit, but everything else seems to be fine. This adds meta support for _s_where, do we get test coverage for free?

I'm not sure; how are meta functions usually tested? I'd be happy to add tests if they're missing.

It looks like `_s_where` was removed by #73468, but that change was reverted, so I'm putting this PR on the back burner for a few days until all of this stabilizes.

mikeiovine pushed a commit that referenced this pull request Mar 9, 2022
Pull Request resolved: #73438

ghstack-source-id: 150942130
mikeiovine pushed a commit that referenced this pull request Mar 11, 2022
Pull Request resolved: #73438

Add out variant for `where.self`; requires PyTorch core changes as no out variant existed previously

ghstack-source-id: 151197403
mikeiovine pushed a commit that referenced this pull request Mar 14, 2022
Pull Request resolved: #73438

Add out variant for `where.self`; requires PyTorch core changes as no out variant existed previously
ghstack-source-id: 151281512

mikeiovine pushed a commit that referenced this pull request Mar 16, 2022
Pull Request resolved: #73438

Add out variant for `where.self`; requires PyTorch core changes as no out variant existed previously
ghstack-source-id: 151505601

facebook-github-bot pushed a commit that referenced this pull request Mar 17, 2022
Summary:
Pull Request resolved: #73438

Add out variant for `where.self`; requires PyTorch core changes as no out variant existed previously
ghstack-source-id: 151505601

Test Plan:
* Existing `where` tests in static runtime pass
* CI for core `where` tests

Reviewed By: hlu1

Differential Revision: D34469785

fbshipit-source-id: 8a4ebbf38b2364534fbf43812bfcfdf69ea174b3
@github-actions

Hey @mikeiovine.
You've committed this PR, but it does not have both a 'release notes: ...' and a 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc.), and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc.).
For changes that are 'topic: not user facing' there is no need for a release notes label.


Labels

`cla signed`, `oncall: jit` (Add this issue/PR to JIT oncall triage queue)
