Let bfloat16 support promotion with other types #41698
Conversation
Force-pushed from a9141ee to ad2f22a
💊 CI failures summary (Dr. CI): As of commit fb8fbd7 — 💚 Looks good so far! There are no failures yet. 💚
Force-pushed from 783cb49 to 642bb8e
mruberry left a comment:
Hey @xuhdev! Thanks for the PR! I think there are still some tests to sort out, and maybe we should leave fp16 x bfloat16 undefined for the moment?
Let me know your thoughts.
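For context, "undefined" here means the fp16 x bfloat16 pairing simply errors rather than promoting. A minimal probe of the installed build (a sketch only — behavior differs across PyTorch versions, and later releases promote this pair to float32):

import torch

# Probe whether fp16 x bfloat16 promotion is defined in this build.
# Around the time of this PR the pair was left undefined (promote_types
# raised); later releases promote the pair to torch.float32.
try:
    print(torch.promote_types(torch.float16, torch.bfloat16))
except RuntimeError as err:
    print("undefined:", err)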
Force-pushed from 642bb8e to 24b25f7
nairbv: What devices support bfloat16?
Force-pushed from ede9551 to 5684589
mruberry left a comment:
New changes look good. The test in test_type_promotion still needs to be expanded, though.
To answer @nairbv's question about device support: at least CPUs, CUDA devices, and XLA devices support bfloat16 today. The new test suite I'm developing actually validates bfloat16 pretty well. It found a few issues, but overall bfloat16 seems to be working as expected where we've implemented the dispatch for it.
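A quick way to sanity-check that support locally (a sketch; it assumes a build with bfloat16 kernels for the ops used, and the CUDA branch only runs when a GPU is present):

import torch

# Create a bfloat16 tensor and run a basic op on each available device.
devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16, device=device)
    print(device, (t + t).dtype)  # expected: torch.bfloat16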
Force-pushed from 5684589 to fc5f236
test/test_type_promotion.py (outdated)
You don't really need this test with the default tensor (it will just be float32 or float64, which you're already testing).
This test looks good. One small change is to test that bf + a complex number errors (you can construct complex Python numbers using complex(<real>, <imag>)), and to flip the operands so you're also testing number + bfloat16 tensor and other tensor + bfloat16 tensor. You can do this like so:
bf = torch.tensor(5.5, dtype=torch.bfloat16, device=device)
scalars = (2.2, 5, complex(1, -1))
for scalar in scalars:
    if isinstance(scalar, complex):
        with self.assertRaises(RuntimeError):
            bf + scalar
        with self.assertRaises(RuntimeError):
            scalar + bf
    else:
        self.assertEqual(bf + scalar, scalar + bf)
        # Adding a Python number should keep the bfloat16 dtype.
        self.assertEqual((bf + scalar).dtype, torch.bfloat16)
And then, in the next section, you want to create a tensor of each dtype upfront:
for dtype in torch.testing.get_all_dtypes():
    t = torch.tensor(1, device=device, dtype=dtype)
    if dtype in (torch.float32, torch.float64):
        self.assertEqual(bf + t, t + bf)
        # The result should take the wider floating point dtype.
        self.assertEqual((bf + t).dtype, dtype)
    ...
How does that sound?
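Pulled out of the unittest harness, the proposed logic might look like this standalone sketch (names such as device are fixed here purely for illustration; the real test lives in test/test_type_promotion.py and uses self.assertEqual/self.assertRaises, and the complex-scalar error reflects behavior at the time of this PR):

import torch

device = "cpu"  # hypothetical fixed device; the test suite parametrizes this
bf = torch.tensor(5.5, dtype=torch.bfloat16, device=device)

# Python scalars: addition commutes and keeps the bfloat16 dtype;
# complex scalars were expected to error at the time of this PR.
for scalar in (2.2, 5, complex(1, -1)):
    if isinstance(scalar, complex):
        try:
            bf + scalar
        except RuntimeError:
            pass  # expected then; newer builds may promote to complex instead
    else:
        assert (bf + scalar).dtype == torch.bfloat16
        assert bf + scalar == scalar + bf

# Tensors: bfloat16 + float32/float64 promotes to the wider float dtype.
for dtype in (torch.float32, torch.float64):
    t = torch.tensor(1, device=device, dtype=dtype)
    assert (bf + t).dtype == dtype
    assert bf + t == t + bf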
Yeah, I agree. I adapted your proposal a bit.
Force-pushed from fc5f236 to 0d1ac22
Force-pushed from c1739cc to 4ce0bd2
Force-pushed from 4ce0bd2 to fb8fbd7
Do we have any update?
facebook-github-bot left a comment:
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix #40580