
Commit 7f206e1

Split native_multi_head_attention into cpu/cuda and call into sdp if available
1 parent 44d7ba7 commit 7f206e1

File tree

8 files changed: +336 −40 lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 1 deletion
@@ -13268,7 +13268,8 @@
 - func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)
   variants: function
   dispatch:
-    CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
+    CPU, NestedTensorCPU: native_multi_head_attention_cpu
+    CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out

 - func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
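
For readers who don't work with native_functions.yaml often: the single dispatch row is split into two, so the code generator now registers a separate kernel per backend key. A hand-written sketch of what those generated registrations amount to is below (only the dense CPU/CUDA keys are shown); the declarations are stand-ins for the real definitions in the transformer sources, and the actual codegen output differs in detail:

#include <tuple>
#include <ATen/ATen.h>
#include <torch/library.h>

// Illustrative stand-ins for the kernels; the signatures follow the schema in
// native_functions.yaml, while the real definitions live in the ATen sources.
std::tuple<at::Tensor, at::Tensor> native_multi_head_attention_cpu(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    int64_t embed_dim, int64_t num_head,
    const at::Tensor& qkv_weight, const at::Tensor& qkv_bias,
    const at::Tensor& proj_weight, const at::Tensor& proj_bias,
    const c10::optional<at::Tensor>& mask, bool need_weights,
    bool average_attn_weights, c10::optional<int64_t> mask_type);
std::tuple<at::Tensor, at::Tensor> native_multi_head_attention_cuda(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    int64_t embed_dim, int64_t num_head,
    const at::Tensor& qkv_weight, const at::Tensor& qkv_bias,
    const at::Tensor& proj_weight, const at::Tensor& proj_bias,
    const c10::optional<at::Tensor>& mask, bool need_weights,
    bool average_attn_weights, c10::optional<int64_t> mask_type);

// One registration block per dispatch key, mirroring the two dispatch rows above.
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("_native_multi_head_attention", native_multi_head_attention_cpu);
}
TORCH_LIBRARY_IMPL(aten, CUDA, m) {
  m.impl("_native_multi_head_attention", native_multi_head_attention_cuda);
}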

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 4 additions & 8 deletions
@@ -322,14 +322,11 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
   const int64_t* tensor_size_ptr = tensor_sizes.data_ptr<int64_t>();
   const int64_t* tensor_stride_ptr = tensor_strides.data_ptr<int64_t>();

-  int64_t offset_constant = (tensor_offsets[1] - tensor_offsets[0]) /
-      tensor_size_ptr[0] * tensor_stride_ptr[0];
+  int64_t offset_constant = (tensor_offsets[1] - tensor_offsets[0]) / (tensor_size_ptr[0] * tensor_stride_ptr[0]);
   for (int64_t i = 2; i < n_tensors; i++) {
-    int64_t current_offset_constant =
-        (tensor_offsets[i] - tensor_offsets[i - 1]) /
-        tensor_size_ptr[(i - 1) * tensor_stride_0] *
-        tensor_stride_ptr[(i - 1) * tensor_stride_0];
+    // TODO: When 0 seq_len nested tensors are allowed we need to guard against this
+    int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * tensor_stride_ptr[(i - 1) * tensor_stride_0];
+    int64_t current_offset_constant = (tensor_offsets[i] - tensor_offsets[i - 1]) / previous_numel;
     if (current_offset_constant != offset_constant) {
       return false;
     }
@@ -431,7 +428,6 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
       {Nnz_kv, num_heads, head_dim},
       {nnz_v_stride, head_v_stride, head_dim_stride},
       value_impl->get_storage_offsets()[0]);
-
   std::tuple<Tensor, Tensor> attention_and_weights =
       at::_efficient_attention_forward(
           query_buffer_reshaped.unsqueeze(0),
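
The first hunk above only regroups the arithmetic: each gap between consecutive storage offsets is divided by the previous tensor's element count (size times stride along dim 0), and the loop checks that this ratio is the same for every pair, which is what makes it safe to view the nested buffer as one dense tensor. A standalone illustration of that invariant on plain arrays, written as a hypothetical helper rather than the actual ATen code:

#include <cstdint>
#include <vector>

// Returns true when (offsets[i] - offsets[i - 1]) / numels[i - 1] is the same
// for every consecutive pair, i.e. the nested buffers are spaced uniformly
// through storage relative to the size of the preceding tensor.
bool has_uniform_offset_constant(
    const std::vector<int64_t>& offsets,
    const std::vector<int64_t>& numels) {
  if (offsets.size() < 2) {
    return true;
  }
  const int64_t offset_constant = (offsets[1] - offsets[0]) / numels[0];
  for (size_t i = 2; i < offsets.size(); ++i) {
    const int64_t current = (offsets[i] - offsets[i - 1]) / numels[i - 1];
    if (current != offset_constant) {
      return false;
    }
  }
  return true;
}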

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 51 additions & 21 deletions
@@ -6,6 +6,8 @@
 #include <ATen/Parallel.h>
 #include <ATen/TensorIndexing.h>
 #include <ATen/cpu/vec/vec256/vec256.h>
+#include <ATen/native/transformers/attention.h>
+

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
@@ -14,7 +16,6 @@
 #endif

 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
-
 namespace at {

 namespace native {
@@ -106,6 +107,17 @@ void transform_bias_rescale_qkv_inner_loop(
   }
 }

+Tensor transform_0213(const Tensor& a) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
+  return a.permute({0, 2, 1, 3})
+      .contiguous()
+      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
+}
+
+} // namespace
+
+
 Tensor bmm_nt(const Tensor& a, const Tensor& b) {
   auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
   auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
@@ -118,7 +130,7 @@ Tensor masked_softmax(
     Tensor& attn_scores,
     c10::optional<Tensor> attn_mask,
     const Tensor& query,
-    c10::optional<int64_t> mask_type = NULL) {
+    c10::optional<int64_t> mask_type) {
   if (query.is_nested() && !attn_mask) {
     return at::_nested_tensor_softmax_with_shape(attn_scores, query);
   }
@@ -156,13 +168,6 @@ Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
   return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
 }

-Tensor transform_0213(const Tensor& a) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
-  return a.permute({0, 2, 1, 3})
-      .contiguous()
-      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
-}

 Tensor transform0213_gemm_nt_bias(
     const Tensor& a,
@@ -254,8 +259,6 @@ Tensor qkv_projection(
   return qkv;
 }

-} // namespace
-
 // compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
 std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
     const Tensor& qkv,
@@ -312,7 +315,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
   return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
 }

-std::tuple<Tensor, Tensor> native_multi_head_attention(
+std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
@@ -692,30 +695,57 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
 }

 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
-    const Tensor& query_, const Tensor& key, const Tensor& value,
-    const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
-  return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
-}
+    const Tensor& query_,
+    const Tensor& key,
+    const Tensor& value,
+    const c10::optional<Tensor>& attn_mask_,
+    double dropout_p,
+    bool need_attn_weights,
+    bool is_causal) {
+  return at::_scaled_dot_product_attention_math(
+      query_,
+      key,
+      value,
+      attn_mask_,
+      dropout_p,
+      need_attn_weights,
+      is_causal);
+}

 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     const Tensor& query_, const Tensor& key, const Tensor& value,
     const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
+  // We are not training and we are falling back to math case.
+  // Inputs are required to be contiguous if nested
+  Tensor query_contiguous = query_;
+  Tensor key_contiguous = key;
+  Tensor value_contiguous = value;
+  if (query_.is_nested()) {
+    query_contiguous = query_.contiguous();
+  }
+  if (key.is_nested()) {
+    key_contiguous = key.contiguous();
+  }
+  if (value.is_nested()) {
+    value_contiguous = value.contiguous();
+  }
+
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.
   const auto embed_size = query_.size(-1);
-  const auto query = query_ * (1. / ::sqrt(static_cast<double>(embed_size)));
+  const auto query = query_contiguous * (1. / ::sqrt(static_cast<double>(embed_size)));
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
+    TORCH_CHECK(!query.is_nested() && !key_contiguous.is_nested(),
         "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

     // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.size(-2), S = key.size(-2);
+    const auto L = query.size(-2), S = key_contiguous.size(-2);
     attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
   }
   if (attn_mask.has_value()) {
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
+    TORCH_CHECK(!query.is_nested() && !key_contiguous.is_nested(),
         "_scaled_dot_product_attention: Nested tensors for query / key are not supported "
         "when an explicit attn_mask is set");
     // Convert boolean mask to additive mask; need to invert mask to indicate what to mask *out*.
@@ -734,7 +764,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   if (dropout_p > 0.0) {
     attn = at::dropout(attn, dropout_p, true);
   }
-  const auto output = at::matmul(attn, value);
+  const auto output = at::matmul(attn, value_contiguous);
   return std::make_tuple(output, attn);
 }
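
The nested-tensor handling added to the math fallback only swaps in contiguous copies; the computation itself is unchanged. As a reference, here is a condensed, self-contained sketch of the same steps (scale q, optional causal mask, softmax, dropout, weighted sum) using plain ATen calls. It mirrors the function above for dense inputs but is not the exact code in this file:

#include <cmath>
#include <limits>
#include <tuple>
#include <ATen/ATen.h>

// Condensed sketch of the math fallback: scale q, optionally build an
// additive causal mask, softmax over the key dimension, dropout, then the
// weighted sum against v. Dense (non-nested) inputs assumed.
std::tuple<at::Tensor, at::Tensor> sdp_math_sketch(
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    double dropout_p, bool is_causal) {
  const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
  auto scores = at::matmul(q * scale, k.transpose(-2, -1));
  if (is_causal) {
    // Lower-triangular positions take part in attention; the rest get -inf.
    auto keep = at::ones({q.size(-2), k.size(-2)}, q.options().dtype(at::kBool)).tril();
    scores = scores.masked_fill(
        keep.logical_not(), -std::numeric_limits<double>::infinity());
  }
  auto attn = at::softmax(scores, /*dim=*/-1);
  if (dropout_p > 0.0) {
    attn = at::dropout(attn, dropout_p, /*train=*/true);
  }
  return std::make_tuple(at::matmul(attn, v), attn);
}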

aten/src/ATen/native/transformers/attention.h

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+#pragma once
+#include <ATen/ATen.h>
+#include <c10/macros/Export.h>
+
+namespace at {
+namespace native {
+
+TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);
+TORCH_API Tensor masked_softmax(
+    Tensor& attn_scores,
+    c10::optional<Tensor> attn_mask,
+    const Tensor& query,
+    c10::optional<int64_t> mask_type = NULL);
+
+TORCH_API Tensor transform0213_gemm_nt_bias(
+    const Tensor& a,
+    const Tensor& b,
+    const Tensor& c,
+    const Tensor& query);
+
+TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);
+
+TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);
+
+TORCH_API Tensor qkv_projection(
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const int64_t embed_dim,
+    const Tensor& qkv_weight);
+
+} // namespace native
+} // namespace at
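
With these helpers promoted out of the anonymous namespace and declared with TORCH_API, other translation units (for example the CUDA side of the split) can reuse them instead of duplicating the code. A minimal hypothetical consumer, assuming bmm_nt keeps its batched q x k^T semantics from attention.cpp:

#include <ATen/ATen.h>
#include <ATen/native/transformers/attention.h>

// Hypothetical consumer of the newly exported helpers: raw attention scores
// for tensors laid out as [batch, heads, seq, head_dim].
at::Tensor example_attention_scores(const at::Tensor& q, const at::Tensor& k) {
  return at::native::bmm_nt(q, k);  // -> [batch, heads, q_seq, k_seq]
}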
