Skip to content

Commit 596e945

Browse files
committed
macro fun
1 parent dfa85b8 commit 596e945

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention(
388388
? get_nested_tensor_impl(query)->get_nested_size_tensor().size(0)
389389
: query.sizes()[0];
390390
auto T = query.is_nested() ? 0 : query.sizes()[1];
391+
const auto dim_per_head = D / num_head;
391392
#endif
392393
#ifdef USE_CUDA
393-
const auto dim_per_head = D / num_head;
394-
if (dim_per_head % 8 == 0 && query.is_cuda()) {
394+
const int64_t sdp_dim_per_head = D / num_head;
395+
if (sdp_dim_per_head % 8 == 0 && query.is_cuda()) {
395396
sdp::sdp_params kernel_params{
396397
query, key, value, mask.has_value(), 0.0, need_weights, false};
397398
auto backend = select_sdp_backend(kernel_params);
@@ -401,9 +402,9 @@ std::tuple<Tensor, Tensor> native_multi_head_attention(
401402
auto x = at::linear(query, qkv_weight, qkv_bias);
402403
auto chunks = x.chunk(3, -1);
403404
auto x_size_0 = x.size(0);
404-
chunks[0] = chunks[0].view({x_size_0, -1, num_head, dim_per_head});
405-
chunks[1] = chunks[1].view({x_size_0, -1, num_head, dim_per_head});
406-
chunks[2] = chunks[2].view({x_size_0, -1, num_head, dim_per_head});
405+
chunks[0] = chunks[0].view({x_size_0, -1, num_head, sdp_dim_per_head});
406+
chunks[1] = chunks[1].view({x_size_0, -1, num_head, sdp_dim_per_head});
407+
chunks[2] = chunks[2].view({x_size_0, -1, num_head, sdp_dim_per_head});
407408
chunks[0] = chunks[0].transpose(1, 2);
408409
chunks[1] = chunks[1].transpose(1, 2);
409410
chunks[2] = chunks[2].transpose(1, 2);

test/test_transformers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,6 @@ def test_transformerencoder_square_input(self, with_no_grad, training, enable_ne
288288
@parametrize("training", [True, False])
289289
@parametrize("enable_nested_tensor", [True, False])
290290
@parametrize("device", device_list)
291-
@sdp_kernel(enable_math=True)
292291
def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
293292
def get_a_test_layer(activation, batch_first=False):
294293
d_model = 4

0 commit comments

Comments
 (0)