[fix] type promotion atan2 #43466

Conversation
@nairbv Please review :) Also, I'm not sure where the test should go.

I'd go with test_type_promotion. There are tests in there already for unary versions of the trigonometric ops, and we'd specifically be testing the promotion behavior. In some cases our promotion behavior might also differ from NumPy's.

Sure, will add the test there. Thanks.

@nairbv Please review :)

Gentle ping :)

@nairbv Gentle ping :)

Gentle ping :)
aten/src/ATen/native/BinaryOps.cpp (outdated)

```diff
- Tensor result = at::empty({0}, self.options());
- return native::atan2_out(result, self, other);
+ Tensor result;
+ auto iter = TensorIterator::binary_op(result, self, other);
```
This is cool and completely correct. Would you mind updating it, though? We now have TensorIterator::binary_float_op, which applies to atan2. It makes integer inputs to atan2 always promote to float values (just like div).

Would you make atan2 a binary_float_op and test that integer inputs are accepted and return the expected float values? This would require expanding your test in test_type_promotion and the list of types we test atan2 with here (line 19881 in d2b4534):

```python
('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _float_types),
```
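For illustration, the requested behavior from the user's side would look roughly like this (a sketch of the expected post-change behavior, assuming the default float dtype is float32):

```python
import torch

a = torch.tensor([1, 2], dtype=torch.int64)
b = torch.tensor([2, 1], dtype=torch.int64)

# With binary_float_op promotion, integer inputs yield a floating-point
# result, matching torch.div's behavior.
print(torch.atan2(a, b).dtype)  # torch.float32 (the default float dtype)
print(torch.div(a, b).dtype)    # torch.float32
```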
Sure. Thanks!
```python
# https://github.com/pytorch/pytorch/issues/28502
a = torch.tensor([[True, True], [False, True]], device=device)
self.assertEqual(a.t() == 0, a.t() == False)  # noqa: E712
```
You can use the @dtypes decorator with pairs of dtypes if you like. For example (product here is itertools.product):

```python
from itertools import product

@dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes()))
def test_atan2_type_promotion(self, device, dtypes):
    dtype1, dtype2 = dtypes
    ...
```
Nice! Will use that, thanks!
Actually, it will be hard to use, as include_half is chosen based on the device type.
You can use:

```python
@dtypesIfCUDA(...)
@dtypes(...)
```

CUDA will get the dtypes in the first decorator; all other device types will get the dtypes in the second decorator. If you want CPU-specific dtypes there's also @dtypesIfCPU.
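A sketch of how that could look for this test (the dtype lists and the use of include_half here are illustrative assumptions, not the PR's actual test):

```python
from itertools import product
import torch

# Hypothetical pairing: CUDA runs include half-precision inputs,
# other device types do not.
@dtypesIfCUDA(*product(torch.testing.get_all_dtypes(include_half=True),
                       torch.testing.get_all_dtypes(include_half=True)))
@dtypes(*product(torch.testing.get_all_dtypes(include_half=False),
                 torch.testing.get_all_dtypes(include_half=False)))
def test_atan2_type_promotion(self, device, dtypes):
    dtype1, dtype2 = dtypes
    ...
```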
Codecov Report

```text
@@           Coverage Diff           @@
##           master   #43466   +/-   ##
=======================================
  Coverage   67.83%   67.83%
=======================================
  Files         384      384
  Lines       49962    49962
=======================================
+ Hits        33892    33893       +1
+ Misses      16070    16069       -1
```

Continue to review the full report at Codecov.
💊 CI failures summary: as of commit a558f43, 1 failure on ci.pytorch.org (more details on the Dr. CI page). This comment was automatically generated by Dr. CI.
The updated tests fail, which makes sense. Updating the generated test code for this.
```python
def is_float(dtype):
    return dtype in torch.testing.get_all_fp_dtypes(include_half=include_half, include_bfloat16=False)

def get_binary_float_result_type(x, y):
```
Is there any function (like torch.result_type) which actually does this for binary_float_op promotion behavior?
Unfortunately no. Which makes torch.result_type kind of misleading.
Oh ok.
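For illustration, such a helper could be sketched like this (the function name is hypothetical; torch.result_type and torch.get_default_dtype are the real APIs, and the integral-to-float step mirrors div-style promotion):

```python
import torch

def binary_float_result_type(x, y):
    # Standard promotion first, then, as binary_float_ops do (e.g. torch.div),
    # promote integral/bool results to the default float dtype.
    dtype = torch.result_type(x, y)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    return dtype

# Example: two integer tensors promote to the default float dtype.
a = torch.tensor([1], dtype=torch.int32)
b = torch.tensor([2], dtype=torch.int64)
print(binary_float_result_type(a, b))  # torch.float32 by default
```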
PR might need an update from the XLA side as well, I guess.

I will take a look at the XLA failure. @kshitij12345 Could you open an issue under pytorch/xla to track this? We can also skip the XLA tests for now.

It is fine; I made the change on the XLA side and verified that the test passes with the fix.

@JackCaoG Thanks for the fix. However, it looks like the XLA failures still exist.

I just merged the fix; you should see the test passing now.

Great, thanks!

@mruberry Gentle ping :)
mruberry left a comment:
Nice work as usual, @kshitij12345!
facebook-github-bot left a comment:
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #43360