Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3495,6 +3495,21 @@ def test_multihead_attn_key_padding_mask():
test_multihead_attn_no_masking() # Test MultiheadAttention without masking
test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths

@repeat_test_for_types(ALL_TENSORTYPES)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_multihead_attention_dtype(self, dtype=torch.float):
embed_dim = 128
num_heads = 8
sl = 10
bs = 8
model = nn.MultiheadAttention(embed_dim, num_heads).cuda().to(dtype)
q = torch.randn(sl, bs, embed_dim, device="cuda", dtype=dtype)
k = torch.randn(sl, bs, embed_dim, device="cuda", dtype=dtype)
v = torch.randn(sl, bs, embed_dim, device="cuda", dtype=dtype)
out = model(q, k, v)
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)

def test_normalize(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
Expand Down
3 changes: 1 addition & 2 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3272,8 +3272,7 @@ def multi_head_attention_forward(query, # type: Tensor
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

attn_output_weights = softmax(
attn_output_weights.float(), dim=-1,
dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
attn_output_weights, dim=-1)
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

attn_output = torch.bmm(attn_output_weights, v)
Expand Down