-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add non_blocking to Tensor/Module.to #7312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
3b81824
Add non_blocking to Tensor/Module.to
ssnl 96c6721
flake8
ssnl 827a6c6
Add argparse tests
ssnl 49bcd56
cpp parse
ssnl 15d2345
Use C++ parser
ssnl 7a5fb4e
use a commong parse function with Tensor.to
ssnl 4ce165f
fix test_jit
ssnl 9b30c88
use THPObjectPtr
ssnl f1a8ec5
increase refcount for None, True, and False
ssnl b6fddcf
address comments
ssnl 036f618
address comments
ssnl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| #pragma once | ||
|
|
||
| #include "torch/csrc/python_headers.h" | ||
| #include <ATen/ATen.h> | ||
|
|
||
| #include "torch/csrc/utils/python_arg_parser.h" | ||
| #include "torch/csrc/utils/device.h" | ||
|
|
||
| namespace torch { namespace autograd { namespace utils { | ||
|
|
||
| inline std::tuple<at::optional<torch::Device>, at::optional<at::ScalarType>, bool> | ||
| parse_to_conversion(PyObject *args, PyObject *kwargs) { | ||
| static PythonArgParser parser({ | ||
| "to(Device device=None, ScalarType dtype=None, bool non_blocking=False)", | ||
| "to(ScalarType dtype, bool non_blocking=False)", | ||
| "to(Tensor tensor, bool non_blocking=False)", | ||
| }); | ||
| ParsedArgs<3> parsed_args; | ||
| auto r = parser.parse(args, kwargs, parsed_args); | ||
| if (r.idx == 0) { | ||
| return std::make_tuple(r.deviceOptional(0), r.scalartypeOptional(1), r.toBool(2)); | ||
| } else if (r.idx == 1) { | ||
| return std::make_tuple(at::nullopt, r.scalartype(0), r.toBool(1)); | ||
| } else { | ||
| auto tensor = r.tensor(0); | ||
| return std::make_tuple( | ||
| torch::tensor::getDevice(tensor), | ||
| tensor.type().scalarType(), | ||
| r.toBool(1) | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| }}} // namespace torch::autograd::utils |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.