@@ -388,10 +388,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention(
388388 ? get_nested_tensor_impl (query)->get_nested_size_tensor ().size (0 )
389389 : query.sizes ()[0 ];
390390 auto T = query.is_nested () ? 0 : query.sizes ()[1 ];
391+ const auto dim_per_head = D / num_head;
391392#endif
392393#ifdef USE_CUDA
393- const auto dim_per_head = D / num_head;
394- if (dim_per_head % 8 == 0 && query.is_cuda ()) {
394+ const int64_t sdp_dim_per_head = D / num_head;
395+ if (sdp_dim_per_head % 8 == 0 && query.is_cuda ()) {
395396 sdp::sdp_params kernel_params{
396397 query, key, value, mask.has_value (), 0.0 , need_weights, false };
397398 auto backend = select_sdp_backend (kernel_params);
@@ -401,9 +402,9 @@ std::tuple<Tensor, Tensor> native_multi_head_attention(
401402 auto x = at::linear (query, qkv_weight, qkv_bias);
402403 auto chunks = x.chunk (3 , -1 );
403404 auto x_size_0 = x.size (0 );
404- chunks[0 ] = chunks[0 ].view ({x_size_0, -1 , num_head, dim_per_head });
405- chunks[1 ] = chunks[1 ].view ({x_size_0, -1 , num_head, dim_per_head });
406- chunks[2 ] = chunks[2 ].view ({x_size_0, -1 , num_head, dim_per_head });
405+ chunks[0 ] = chunks[0 ].view ({x_size_0, -1 , num_head, sdp_dim_per_head });
406+ chunks[1 ] = chunks[1 ].view ({x_size_0, -1 , num_head, sdp_dim_per_head });
407+ chunks[2 ] = chunks[2 ].view ({x_size_0, -1 , num_head, sdp_dim_per_head });
407408 chunks[0 ] = chunks[0 ].transpose (1 , 2 );
408409 chunks[1 ] = chunks[1 ].transpose (1 , 2 );
409410 chunks[2 ] = chunks[2 ].transpose (1 , 2 );
0 commit comments