-
Notifications
You must be signed in to change notification settings - Fork 26.3k
is_numpy_scalar should also consider bool and complex types #43644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
I couldn't find any Python exposure of |
You'd have to test it in C++ if you wanted to test it directly. Instead, however, you can probably test the desired effect (testing |
68ce73e to
d7e81e5
Compare
|
@mruberry I've added test now |
d7e81e5 to
052d19a
Compare
Codecov Report
@@ Coverage Diff @@
## master #43644 +/- ##
=======================================
Coverage 69.28% 69.29%
=======================================
Files 379 379
Lines 47035 47035
=======================================
+ Hits 32590 32591 +1
+ Misses 14445 14444 -1
Continue to review full report at Codecov.
|
test/test_torch.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently the complex tensors won't actually have any complex values and there's only a single 1D tensor. What about extending the test slightly to support complex values and multiple tensor sizes? It could do something like this:
is_complex = dtype in (torch.cfloat, torch.cdouble)
if is_complex:
tensors = (...)
else:
tensors = (
torch.tensor(5, dtype=dtype, device=device),
torch.tensor([0, 1, 2], dtype=dtype, device=device),
torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device),
)
for tensor in tensors:
np_array = tensor.cpu().numpy()
for t, a in product((tensor[0], tensor[0].item()), (np_array[0], np_array[0].item()):
self.assertEqual(t, a)
if not is_complex:
self.assertTrue(t == a)
if is_complex:
self.assertTrue(t[0] == np_array[0].item())
self.assertTrue(t[0].item() == np_array[0])
self.assertFalse(t[0] == np_array[0])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per the following discussion, if you follow this pattern the later conditionals will probably need to check dtype is torch.cfloat instead of the is_complex boolean.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't self.assertFalse(t[0] == np_array[0]) be true only for cfloat, not for cdouble?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thanks. Anyway, I have structured the test a bit differently.
|
Hey @xuhdev! This looks really cool. Just have one question about testing for you to review. Looking forward to hearing your thoughts! |
052d19a to
0a74895
Compare
944b26f to
5faeaa0
Compare
test/test_torch.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mruberry do we have any function similar to _make_tensors in TestTorchDeviceType?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recap of offline discussion: #43451 is adding a make_tensor method to common_utils.py.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@xuhdev can you rebase this PR again? phabricator is complaining since another PR landed touching files edited in this PR |
Before this PR, ```python import torch import numpy as np a = torch.tensor([1, 2], dtype=torch.bool) c = np.array([1, 2], dtype=np.bool) print(a[0] == c[0]) a = torch.tensor([1, 2], dtype=torch.complex64) c = np.array([1, 2], dtype=np.complex64) print(a[0] == c[0]) # This case is still broken a = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64) c = np.array([1 + 1j, 2 + 2j], dtype=np.complex64) print(a[0] == c[0]) ``` outputs ``` False False False ``` After this PR, it outputs: ``` tensor(True) /home/hong/xusrc/pytorch/torch/tensor.py:25: ComplexWarning: Casting complex values to real discards the imaginary part return f(*args, **kwargs) tensor(True) tensor(False) ``` Related issue: pytorch#43579
5faeaa0 to
fc4728d
Compare
|
@anjali411 Done! |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@anjali411 merged this pull request in 4bb5d33. |
Before this PR,
outputs
After this PR, it outputs:
Related issue: #43579
cc @anjali411 @mruberry