#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/cpu/vec/vec256/vec256.h>
+#include <ATen/native/transformers/attention.h>
+

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#endif

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
-
namespace at {

namespace native {
@@ -106,6 +107,17 @@ void transform_bias_rescale_qkv_inner_loop(
  }
}

+Tensor transform_0213(const Tensor& a) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
+  return a.permute({0, 2, 1, 3})
+      .contiguous()
+      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
+}
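// Illustrative sketch, not from the source file: transform_0213 permutes the
// head and sequence dimensions and folds the heads back into the embedding
// dimension, i.e. [B, num_head, T, dim_per_head] -> [B, T, num_head * dim_per_head].
// Hypothetical usage, assuming a 4-D input:
//
//   auto x = at::randn({2, 8, 16, 64});  // [B, H, T, D]
//   auto y = transform_0213(x);          // y.sizes() == {2, 16, 8 * 64}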
+
+} // namespace
+
+
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
@@ -118,7 +130,7 @@ Tensor masked_softmax(
    Tensor& attn_scores,
    c10::optional<Tensor> attn_mask,
    const Tensor& query,
-    c10::optional<int64_t> mask_type = NULL) {
+    c10::optional<int64_t> mask_type) {
  if (query.is_nested() && !attn_mask) {
    return at::_nested_tensor_softmax_with_shape(attn_scores, query);
  }
@@ -156,13 +168,6 @@ Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}

-Tensor transform_0213(const Tensor& a) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
-  return a.permute({0, 2, 1, 3})
-      .contiguous()
-      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
-}

Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
@@ -254,8 +259,6 @@ Tensor qkv_projection(
  return qkv;
}

-} // namespace
-
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
    const Tensor& qkv,
@@ -312,7 +315,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}

-std::tuple<Tensor, Tensor> native_multi_head_attention(
+std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
@@ -692,30 +695,57 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
}

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
-    const Tensor& query_, const Tensor& key, const Tensor& value,
-    const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
-  return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
- }
+    const Tensor& query_,
+    const Tensor& key,
+    const Tensor& value,
+    const c10::optional<Tensor>& attn_mask_,
+    double dropout_p,
+    bool need_attn_weights,
+    bool is_causal) {
+  return at::_scaled_dot_product_attention_math(
+      query_,
+      key,
+      value,
+      attn_mask_,
+      dropout_p,
+      need_attn_weights,
+      is_causal);
+}

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
    const Tensor& query_, const Tensor& key, const Tensor& value,
    const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
+  // We are not training and we are falling back to the math case.
+  // Inputs are required to be contiguous if nested.
+  Tensor query_contiguous = query_;
+  Tensor key_contiguous = key;
+  Tensor value_contiguous = value;
+  if (query_.is_nested()) {
+    query_contiguous = query_.contiguous();
+  }
+  if (key.is_nested()) {
+    key_contiguous = key.contiguous();
+  }
+  if (value.is_nested()) {
+    value_contiguous = value.contiguous();
+  }
+
  auto attn_mask = attn_mask_;
  // Naive, composite implementation defined here.
  const auto embed_size = query_.size(-1);
-  const auto query = query_ * (1. / ::sqrt(static_cast<double>(embed_size)));
+  const auto query = query_contiguous * (1. / ::sqrt(static_cast<double>(embed_size)));
  if (is_causal) {
    TORCH_CHECK(!attn_mask.has_value(),
        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
+    TORCH_CHECK(!query.is_nested() && !key_contiguous.is_nested(),
        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.size(-2), S = key.size(-2);
+    const auto L = query.size(-2), S = key_contiguous.size(-2);
    attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
  }
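// Worked example (illustration only, not part of the diff): with L = S = 3,
// at::ones({L, S}, kBool).tril() yields
//   1 0 0
//   1 1 0
//   1 1 1
// so query position i can only attend to key positions 0..i. The boolean mask
// is then handled by the conversion step below, which turns positions marked
// as masked out into large negative additive values before the softmax.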
  if (attn_mask.has_value()) {
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
+    TORCH_CHECK(!query.is_nested() && !key_contiguous.is_nested(),
        "_scaled_dot_product_attention: Nested tensors for query / key are not supported "
        "when an explicit attn_mask is set");
    // Convert boolean mask to additive mask; need to invert mask to indicate what to mask *out*.
@@ -734,7 +764,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
  if (dropout_p > 0.0) {
    attn = at::dropout(attn, dropout_p, true);
  }
-  const auto output = at::matmul(attn, value);
+  const auto output = at::matmul(attn, value_contiguous);
  return std::make_tuple(output, attn);
}
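// Summary sketch (illustration only, not part of the diff): end to end, the
// math fallback above computes
//   attn   = softmax(matmul(q * (1 / sqrt(E)), k.transpose(-2, -1)) + attn_mask)
//   attn   = dropout(attn, dropout_p)      // only when dropout_p > 0
//   output = matmul(attn, v)
// where E is the embedding size (the last dimension of the query) and the
// attention weights are returned alongside the output.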