-
Notifications
You must be signed in to change notification settings - Fork 26.3k
detach returned tensors from source tensors in tensor constructors #11815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
detach returned tensors from source tensors in tensor constructors #11815
Conversation
c599748 to
d92b7dc
Compare
d92b7dc to
9e9fdf3
Compare
| if (r.idx == 0) { | ||
| PyObject* data = r.pyobject(0); | ||
| if (THPVariable_Check(data)) { | ||
| PyErr_WarnEx(PyExc_UserWarning, |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| a[0] = 7. | ||
| self.assertEqual(5., res1[0].item()) | ||
|
|
||
| def test_tensor_factory_copy_var(self): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
…ent dtype and device
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
test/test_torch.py
Outdated
| # test torch.as_tensor() | ||
| source = source.add(1) # make source non-leaf | ||
| check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad) | ||
| check_copy(torch.as_tensor(source, dtype=torch.float), True, False) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/utils/tensor_new.cpp
Outdated
| r.pyobject(0)) | ||
| .set_requires_grad(r.toBool(3)); | ||
| data); | ||
| new_tensor.detach_(); // making new_tensor a leaf node |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| PyObject* data = r.pyobject(0); | ||
| if (THPVariable_Check(data)) { | ||
| PyErr_WarnEx(PyExc_UserWarning, | ||
| "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (r.idx == 0) { | ||
| PyObject* data = r.pyobject(0); | ||
| if (THPVariable_Check(data)) { | ||
| PyErr_WarnEx(PyExc_UserWarning, |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/utils/tensor_new.cpp
Outdated
| args_requires_grad) | ||
| .set_requires_grad(r.toBool(3)); | ||
| type_inference); | ||
| new_tensor.detach_(); // making new_tensor a leaf node |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/utils/tensor_new.cpp
Outdated
| ParsedArgs<3> parsed_args; | ||
| auto r = parser.parse(args, kwargs, parsed_args); | ||
| if (r.idx == 0) { | ||
| at::optional<Device> device_opt = r.deviceOptional(2); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
apaszke
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Semantics of torch.as_tensor need to be clarified.
test/test_torch.py
Outdated
| def check_copy(copy, copy_is_leaf, copy_requires_grad): | ||
| self.assertEqual(copy.data, source.data) | ||
| self.assertTrue(copy.is_leaf == copy_is_leaf) | ||
| self.assertTrue(copy.requires_grad == copy_requires_grad) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/_torch_docs.py
Outdated
| :func:`as_tensor()` reads out 'the data' from whatever it is passed, | ||
| and constructs a leaf variable. Therefore ``tensor.as_tensor(x, dtype=dtype, device=device)`` | ||
| is equivalent to ``x.clone().detach().to(dtype, device)``. | ||
| The equivalents using ``clone()`` and ``detach()`` are recommended. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (r.idx == 0) { | ||
| PyObject* data = r.pyobject(0); | ||
| if (THPVariable_Check(data)) { | ||
| PyErr_WarnEx(PyExc_UserWarning, |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/utils/tensor_new.cpp
Outdated
| typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0), false, false, type_inference); | ||
| bool is_copy = false; | ||
|
|
||
| if (THPVariable_Check(data)) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Currently It doesn't really make sense. Does this patch remove that? |
|
@ssnl Yes, it warns because |
a211893 to
36d298c
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: - fix PR pytorch#11061 by moving `detach_()` and `set_requires_grad()` to `torch.tensor_ctor()` and `tensor.new_tensor`, and also removed warnings and `args_requires_grad` from `internal_new_from_data ` - with this patch, the returned tensor from `tensor_ctor()` and `new_tensor` will be detached from source tensor, and set requires_grad based on the input args - `torch.as_tensor` retains its behavior as documented gchanan apaszke Pull Request resolved: pytorch#11815 Differential Revision: D9932713 Pulled By: weiyangfb fbshipit-source-id: 4290cbc57bd449954faadc597c24169a7b2d8259
detach_()andset_requires_grad()totorch.tensor_ctor()andtensor.new_tensor, and also removed warnings andargs_requires_gradfrominternal_new_from_datatensor_ctor()andnew_tensorwill be detached from source tensor, and set requires_grad based on the input argstorch.as_tensorretains its behavior as documented@gchanan @apaszke