Skip to content

Conversation

@mruberry
Copy link
Collaborator

@mruberry mruberry commented Aug 20, 2020

Implements bfloat16 type promotion consistent with JAX (see https://jax.readthedocs.io/en/latest/type_promotion.html), addressing issue #43049.

  • bfloat16 x float16 -> float32
  • bfloat16 x complex64 -> complex64
  • bfloat16 x complex128 -> complex128

Existing tests, after updates, are sufficient to validate the new behavior.

cc @xuhdev

@mruberry mruberry requested a review from gchanan August 20, 2020 10:55
self.assertEqual((bf + scalar).dtype, torch.bfloat16)
self.assertEqual((scalar + bf).dtype, torch.bfloat16)
self.assertEqual(scalar + bf, bf + scalar)
with self.assertRaises(RuntimeError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a complex scalar to the for loop above? That seems to capture the intent of this test.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I also simplified the loop.

self.assertEqual(bf + t, t + bf)
if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble):
# Handles bfloat16 x float16 -> float32 promotion
expected_dtype = dtype if dtype != torch.half else torch.float32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you rationalize?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rationalized.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@dr-ci
Copy link

dr-ci bot commented Aug 21, 2020

💊 CI failures summary and remediations

As of commit f4922f4 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Aug 21 00:36:07 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future
Aug 21 00:36:07 At: 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Aug 21 00:36:07  
Aug 21 00:36:07 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future 
Aug 21 00:36:07  
Aug 21 00:36:07 At: 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Aug 21 00:36:07  
Aug 21 00:36:07 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future 
Aug 21 00:36:07  
Aug 21 00:36:07 At: 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Aug 21 00:36:07   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Aug 21 00:36:07  
Aug 21 00:36:07 ok (1.340s) 
Aug 21 00:36:09   test_return_future_remote (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.283s) 
Aug 21 00:36:10   test_return_local_rrefs (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.365s) 
Aug 21 00:36:11   test_rpc_return_rref (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.289s) 
Aug 21 00:36:19   test_rpc_timeouts (__main__.ProcessGroupRpcTestWithSpawn) ... ok (7.873s) 

This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 1 time.

@facebook-github-bot
Copy link
Contributor

@mruberry merged this pull request in 3aec118.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants