
Conversation

@xuhdev (Collaborator) commented Jul 20, 2020

Fix #40580
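
For context, issue #40580 reports that multiplying a bfloat16 tensor with a double tensor raises a RuntimeError instead of promoting. A minimal illustration of the behavior this PR enables (a sketch, assuming the standard promotion rule that the wider floating-point dtype wins):

import torch

a = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
b = torch.tensor([3.0, 4.0], dtype=torch.float64)

# Previously this raised a RuntimeError; with this PR the result
# promotes to the wider floating-point dtype.
c = a * b
print(c.dtype)  # torch.float64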

@xuhdev xuhdev requested a review from mruberry July 20, 2020 22:16
@xuhdev xuhdev force-pushed the bfloat16-promotion branch from a9141ee to ad2f22a Compare July 20, 2020 22:17
@dr-ci (bot) commented Jul 20, 2020

💊 CI failures summary and remediations

As of commit fb8fbd7 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@xuhdev xuhdev requested review from anjali411, nairbv and ngimel July 20, 2020 22:34
@gchanan gchanan added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jul 21, 2020
@xuhdev xuhdev force-pushed the bfloat16-promotion branch 2 times, most recently from 783cb49 to 642bb8e Compare July 21, 2020 18:17
@mruberry (Collaborator) left a comment

Hey @xuhdev! Thanks for the PR! I think there are still some tests to sort out, and maybe we should leave fp16 x bfloat16 undefined for the moment?

Let me know your thoughts.
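
For context, leaving fp16 x bfloat16 undefined would mean that mixing the two 16-bit float dtypes raises rather than silently promoting. A hypothetical check of that behavior (an assumption about what "undefined" means here, not code from this PR):

import torch

h = torch.tensor([1.0], dtype=torch.float16)
bf = torch.tensor([1.0], dtype=torch.bfloat16)

# If fp16 x bfloat16 stays undefined, mixing the dtypes should raise.
try:
    h * bf
except RuntimeError as e:
    print("fp16 x bfloat16 rejected:", e)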

@nairbv (Collaborator) commented Jul 22, 2020

What devices support bfloat16?

In torch/testing/__init__.py there's a function get_all_math_dtypes(device), but it doesn't appear to set include_bfloat16=True for any device type. If we fix that to handle bfloat16 too, I think it would enable some other arithmetical mixed-type testing.
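
A sketch of the change being suggested, with an assumed dtype list (illustrative only, not the actual contents of torch/testing/__init__.py):

import torch

def get_all_math_dtypes(device):
    dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32,
              torch.int64, torch.float32, torch.float64]
    if str(device).startswith('cuda'):
        dtypes.append(torch.float16)
    # The suggested fix: also include bfloat16 for device types that
    # support it, instead of never passing include_bfloat16=True.
    dtypes.append(torch.bfloat16)
    return dtypes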

@xuhdev xuhdev force-pushed the bfloat16-promotion branch 3 times, most recently from ede9551 to 5684589 Compare July 23, 2020 02:03
@xuhdev xuhdev requested a review from mruberry July 23, 2020 05:06
@mruberry mruberry self-requested a review July 23, 2020 06:52
@mruberry (Collaborator) left a comment

New changes look good. The test in test_type_promotion still needs to be expanded, though.

To answer @nairbv's question about device support: at least CPUs, CUDA devices, and XLA devices support bfloat16 today. The new test suite I'm developing actually validates bfloat16 pretty well. It found a few issues but overall bfloat16 seems to be working as expected where we've implemented the dispatch for it.
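
For reference, the promotion rules under discussion can be inspected directly with torch.promote_types and torch.result_type (a quick illustration, not part of the PR):

import torch

# bfloat16 promotes with the wider float, like the other float dtypes.
print(torch.promote_types(torch.bfloat16, torch.float32))  # torch.float32
print(torch.promote_types(torch.bfloat16, torch.float64))  # torch.float64
print(torch.promote_types(torch.bfloat16, torch.int64))    # torch.bfloat16

# result_type also handles tensor/scalar mixes; a Python float scalar
# does not change a floating-point tensor's dtype.
bf = torch.tensor(1.0, dtype=torch.bfloat16)
print(torch.result_type(bf, 2.2))  # torch.bfloat16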

@xuhdev xuhdev force-pushed the bfloat16-promotion branch from 5684589 to fc5f236 Compare July 23, 2020 17:53
@xuhdev xuhdev requested a review from mruberry July 23, 2020 17:54
@mruberry (Collaborator) left a comment

You don't really need this test with the default tensor (it will just be float32 or float64, which you're already testing).

This test looks good. One small change: also test that bf + a complex number errors (you can construct complex Python numbers using complex(<real>, <imag>)), and flip the tests so you're also adding the number / other tensor + the bfloat16 tensor. Something like this:

bf = torch.tensor(5.5, dtype=torch.bfloat16, device=device)
scalars = (2.2, 5, complex(1, -1))
for scalar in scalars:
  if isinstance(scalar, complex):
    with self.assertRaises(RuntimeError):
      bf + scalar
    with self.assertRaises(RuntimeError):
      scalar + bf
  else:
    self.assertEqual(bf + scalar, scalar + bf)
    self.assertEqual((bf + scalar).dtype, torch.bfloat16)

And then in the next section you want to create the tensor of the dtype upfront:

for dtype in torch.testing.get_all_dtypes():
  t = torch.tensor(1, device=device, dtype=dtype)
  if dtype in (torch.float32, torch.float64):
    self.assertEqual(bf + t, t + bf)
    self.assertEqual((bf + t).dtype, dtype)
...

How does that sound?

@xuhdev (Collaborator, Author) left a comment

Yeah, I agree. I adapted your proposal a bit.

@xuhdev xuhdev force-pushed the bfloat16-promotion branch from fc5f236 to 0d1ac22 Compare July 23, 2020 18:32
@xuhdev xuhdev requested a review from mruberry July 23, 2020 18:32
@xuhdev xuhdev force-pushed the bfloat16-promotion branch 2 times, most recently from c1739cc to 4ce0bd2 Compare July 23, 2020 18:36
@xuhdev xuhdev force-pushed the bfloat16-promotion branch from 4ce0bd2 to fb8fbd7 Compare July 24, 2020 00:24
@xuhdev (Collaborator, Author) commented Jul 29, 2020

Any update on this?

@mruberry (Collaborator) commented

Hey @xuhdev, sorry to keep you waiting. I had a chance to talk to @nairbv offline and we're satisfied. Nice job!

@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@mruberry merged this pull request in 344defc.

@xuhdev xuhdev deleted the bfloat16-promotion branch August 18, 2020 23:47

Labels

Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Multiply a bfloat16 tensor with a double tensor leads to RuntimeError

6 participants