Conversation

@ngimel (Collaborator) commented on Jun 11, 2019

Currently, multi-head attention is broken for the half (fp16) type:

  File "/home/ngimel/pytorch/torch/nn/functional.py", line 3279, in multi_head_attention_forward
    attn_output = torch.bmm(attn_output_weights, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2'
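
For reference, a call along these lines reproduces the traceback above (a minimal sketch with made-up shapes, not taken from the PR; it needs a CUDA device for fp16):

```python
import torch
import torch.nn as nn

# Hypothetical reproduction: run MultiheadAttention entirely in fp16 on GPU.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4).cuda().half()
q = torch.randn(10, 2, 64, device="cuda", dtype=torch.half)  # (seq, batch, embed)

# Before this fix, the softmax inside multi_head_attention_forward returned
# fp32 attention weights while the value tensor stayed fp16, so the internal
# bmm raised "Expected object of scalar type Float but got scalar type Half".
out, weights = mha(q, q, q)
```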

The failure happens because the softmax inside multi_head_attention_forward converts half inputs into fp32 outputs. This is unnecessary: all of softmax's internal computation is done in fp32 anyway, and the result would have to be converted back to fp16 for the subsequent batch matrix multiply, so nothing is gained by writing the weights out in fp32. This PR removes the type cast in softmax so that half inputs work.
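
The change amounts to dropping the fp32 cast around that internal softmax. Below is a rough before/after sketch on standalone tensors, not the literal diff from this PR (the exact form of the old cast is assumed; tensor names and shapes are placeholders):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the internal attention weights and value tensor (arbitrary shapes).
attn_output_weights = torch.randn(2, 4, 4, dtype=torch.half, device="cuda")
v = torch.randn(2, 4, 8, dtype=torch.half, device="cuda")

# Before: asking softmax for an fp32 result leaves the weights in fp32, so the
# following bmm against the fp16 value tensor fails with a dtype mismatch.
w_fp32 = F.softmax(attn_output_weights, dim=-1, dtype=torch.float32)
# torch.bmm(w_fp32, v)  # RuntimeError: Expected Float but got Half for 'mat2'

# After: let softmax keep the input dtype; for half inputs its kernel still
# accumulates in fp32 internally, so accuracy is unchanged and the bmm now
# sees matching fp16 operands.
w_half = F.softmax(attn_output_weights, dim=-1)
attn_output = torch.bmm(w_half, v)
```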

@ngimel requested a review from @zhangguanheng66 on Jun 11, 2019 at 22:01
@pytorchbot added the module: nn (Related to torch.nn) label on Jun 11, 2019
@zhangguanheng66 (Contributor) left a comment

The failed tests are not relevant. Ready to Merge. Thanks for the contribution @ngimel

@facebook-github-bot (Contributor) left a comment

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@zhangguanheng66 merged this pull request in efd20de.
