
Commit df14650

drisspg authored and pytorchmergebot committed
[SDPA] Update SDPA API and make function Public (#92189)
# Summary

In preparation for the PyTorch 2.0 launch, this PR updates SDPA's API and makes the function a public `nn.functional` function.

## Changes

### API

Previously the function signature was:

`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`

Updated signature:

`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`

This PR removes the optional `need_attn_weights` boolean and changes the return type to a single tensor.

#### Reasoning

The main goal of this function is to give users an easy interface for calling into fused attention kernels, e.g. FlashAttention. The fused kernels do not currently support an arbitrary `attn_mask` or dropout, but there is a PR to mem-efficient attention to enable these, and we want the API surface ready for when the backing kernels are updated. The fused kernels save memory by not materializing the attention weights, and it is unlikely that a fast fused implementation will ever return them, so we are removing `need_attn_weights`. This API change was discussed with folks at FAIR/xformers, who +1'd it.

#### Make function public

In preparation for the PyTorch 2.0 launch, we make the function public to start gathering user feedback.

Pull Request resolved: #92189
Approved by: https://github.com/cpuhrsch
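A minimal usage sketch of the updated public API (shapes are illustrative; on CPU this falls through to the math backend):

```python
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# New public API: returns a single Tensor; attention weights are never materialized.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])

# The old private op took a need_attn_weights flag and returned an (output, weights)
# tuple; it is kept only for backward compatibility (see the native_functions.yaml diff below).
```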
1 parent 1237cf6 · commit df14650

26 files changed (+428 −426 lines)

aten/src/ATen/autocast_mode.cpp

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(rnn_tanh_cell, lower_precision_fp)
   KERNEL(rnn_relu_cell, lower_precision_fp)
   KERNEL(_scaled_dot_product_flash_attention, lower_precision_fp)
-  KERNEL(_scaled_dot_product_attention, lower_precision_fp)
+  KERNEL(scaled_dot_product_attention, lower_precision_fp)
 
   // fp32
   KERNEL(acos, fp32)
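Because `scaled_dot_product_attention` is now registered under autocast's lower_precision_fp policy, autocast casts its inputs down before dispatch. A minimal sketch of the expected behavior, assuming a CUDA device (shapes and dtype are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Under autocast, inputs to scaled_dot_product_attention are cast to the
# lower-precision autocast dtype before the kernel runs.
with torch.autocast("cuda", dtype=torch.float16):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.dtype)  # torch.float16
```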

aten/src/ATen/functorch/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(rrelu);
   OP_DECOMPOSE(prelu);
   OP_DECOMPOSE2(softmax, int);
+  OP_DECOMPOSE(scaled_dot_product_attention);
   OP_DECOMPOSE(special_gammainc);
   OP_DECOMPOSE(special_gammaincc);
   OP_DECOMPOSE(special_logit);
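Registering the op as a functorch batched decomposition lets vmap batch it by decomposing into ops that already have batch rules. A small sketch, assuming a torch.func-era build:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

# (batch, heads, seq, head_dim); vmap maps over the leading batch dimension,
# so each inner call sees 3-D (heads, seq, head_dim) inputs.
q = torch.randn(4, 8, 32, 64)
k = torch.randn(4, 8, 32, 64)
v = torch.randn(4, 8, 32, 64)

out = vmap(F.scaled_dot_product_attention)(q, k, v)
print(out.shape)  # torch.Size([4, 8, 32, 64])
```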

aten/src/ATen/native/native_functions.yaml

Lines changed: 10 additions & 4 deletions
@@ -13948,21 +13948,27 @@
     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out
 
+# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
 - func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
   python_module: nn
   variants: function
   autogen: _scaled_dot_product_attention.out
 
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
+  python_module: nn
+  variants: function
+  autogen: scaled_dot_product_attention.out
+
 # This aten function is kept so that we can test the choice function from Python
-- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> int
+- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
   dispatch:
     CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
 
-- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> (Tensor, Tensor)
   variants: function
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False) -> (Tensor, Tensor)
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

@@ -13980,7 +13986,7 @@
   dispatch:
     CUDA: _chunk_grad_outputs_efficient_attention
 # Returns ouput, softmax_logsumexp, softmax
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> (Tensor, Tensor)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
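As the schema comment notes, `_fused_sdp_choice` is kept so the backend choice can be tested from Python. A hedged sketch of that kind of check (it assumes the private op is exposed as `torch._fused_sdp_choice`, as auto-generated ATen functions typically are; argument names follow the schema above):

```python
import torch

q = torch.randn(2, 8, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Returns an int that maps onto sdp::SDPBackend. On CPU the choice is always the
# math backend (per _fused_sdp_choice_cpp); on CUDA a fused kernel may be selected
# when the inputs qualify.
backend = torch._fused_sdp_choice(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(backend)
```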

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 2 additions & 6 deletions
@@ -321,12 +321,11 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
 
 } // namespace
 
-std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
+std::tuple<Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
     double dropout_p,
-    bool return_softmax,
     bool is_causal) {
   TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.")
   // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)

@@ -373,13 +372,12 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedten
       cumulative_sequence_length_k,
       max_seqlen_batch_q,
       max_seqlen_batch_k,
-      return_softmax,
       dropout_p,
       is_causal);
   // Reshape output to convert nnz to batch_size and seq_len
   Tensor attention = std::get<0>(attention_and_lse_and_softmax);
   attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2);
-  return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax));
+  return std::tie(attention, std::get<1>(attention_and_lse_and_softmax));
 }
 
 std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_nestedtensor_cuda(

@@ -496,7 +494,6 @@ Tensor flash_attention_helper(
     const Tensor& key,
     const Tensor& value,
     double dropout_p,
-    bool need_atten_weights,
     bool is_causal) {
   // Query is of size (batch_size x ragged_seq_len x (3 or 1) x n_heads x
   // head_did

@@ -541,7 +538,6 @@ Tensor flash_attention_helper(
       cumulative_sequence_length_q,
       max_seqlen_batch_q,
       max_seqlen_batch_q,
-      false /*return_softmax*/,
       dropout_p,
       is_causal));
   // Output of flash_attention is a regular tensor lets wrap it back up to

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 31 additions & 34 deletions
@@ -658,10 +658,30 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> native_decoder_only_multi_head_attent
 }
 
 int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
   return static_cast<int64_t>(sdp::SDPBackend::math);
 }
 
+// !!!!!! TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS
+// WITH THIS OP BUILTIN !!!!!!
+std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
+    const Tensor& query_,
+    const Tensor& key,
+    const Tensor& value,
+    const c10::optional<Tensor>& attn_mask_,
+    double dropout_p,
+    bool need_attn_weights,
+    bool is_causal) {
+  if (!need_attn_weights) {
+    return std::make_tuple(
+        at::scaled_dot_product_attention(
+            query_, key, value, attn_mask_, dropout_p, is_causal),
+        Tensor());
+  }
+  return at::_scaled_dot_product_attention_math(
+      query_, key, value, attn_mask_, dropout_p, is_causal);
+}
+
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.

@@ -690,69 +710,52 @@ int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Ten
 // S: Source sequence length
 // L: Target sequence length
 // E: Embedding dimension
-std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
+Tensor scaled_dot_product_attention(
     const Tensor& query_,
     const Tensor& key,
     const Tensor& value,
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
-    bool need_attn_weights,
     bool is_causal) {
-  // TODO: The second return is the attention weights if the math kernel is
-  // used. The fused kernels do not return this Tensor so for the fused kernels
-  // The second return SHOULD always be an empty Tensor, unless need_attn_weights
-  // is true (in which case the fused kernels would not be called). This blows up
-  // op_info tests.
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
-        query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
+        query_, key, value, attn_mask_, dropout_p, is_causal);
   }
   sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
   switch (backend) {
     case sdp::SDPBackend::flash_attention: {
       auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
-          query_, key, value, dropout_p, need_attn_weights, is_causal);
-      return std::make_tuple(
-          std::move(std::get<0>(out_lse_softmax)),
-          std::move(std::get<2>(out_lse_softmax)));
+          query_, key, value, dropout_p, is_causal);
+      return std::get<0>(out_lse_softmax);
     }
     case sdp::SDPBackend::efficient_attention: {
       bool compute_logsumexp =
           (query_.requires_grad() || key.requires_grad() ||
            value.requires_grad());
       auto out_and_lse = at::_scaled_dot_product_efficient_attention(
           query_, key, value, compute_logsumexp, is_causal);
-      // We need to make an empty tensor in the shape of attention weights
-      // for the sake of meta tensors.
-      if (query_.is_nested()) {
-        // TODO: Need to fix when we have empty for nested tensors.
-        return out_and_lse;
-      }
-      return std::make_tuple(
-          std::move(std::get<0>(out_and_lse)),
-          at::empty_symint({0}, query_.options()));
+      return std::get<0>(out_and_lse);
     }
     case sdp::SDPBackend::math:
-      return at::_scaled_dot_product_attention_math(
+      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask_,
          dropout_p,
-          need_attn_weights,
-          is_causal);
+          is_causal));
     default:
       TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
-      return std::make_tuple(Tensor(), Tensor());
+      return Tensor();
   }
 }
 
 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
        const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal) {
   C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
   if (query_.is_nested() || key.is_nested() || value.is_nested()) {
     TORCH_CHECK(

@@ -797,13 +800,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   if (dropout_p > 0.0) {
     attn = at::dropout(attn, dropout_p, true);
   }
-  const auto output = at::matmul(attn, value);
-  // If you don't need it then you don't get it.
-  // TODO: Need to fix when we have empty for nested tensors.
-  attn = need_attn_weights || query_.is_nested()
-      ? attn
-      : at::empty_symint({0}, query_.options());
-  return std::make_tuple(output, attn);
+  return std::make_tuple(at::matmul(attn, value), attn);
 }
 
 Tensor triton_multi_head_attention(
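For context, the math backend selected in the switch above computes plain attention. A rough Python sketch of the semantics of `_scaled_dot_product_attention_math` (an approximation only; the real path also handles boolean masks, nested tensors, and mask/causal validation):

```python
import math
import torch

def sdpa_math_reference(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    """Rough equivalent of the math fallback: softmax(Q @ K^T / sqrt(E) [+ mask]) @ V."""
    L, S, E = query.size(-2), key.size(-2), query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(E)
    if is_causal:
        # Lower-triangular mask; positions a query may not attend to get -inf.
        causal = torch.tril(torch.ones(L, S, device=query.device)).bool()
        scores = scores.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        scores = scores + attn_mask  # additive (float) mask path
    attn = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn = torch.dropout(attn, dropout_p, True)
    # The public scaled_dot_product_attention returns only this tensor; the private
    # math op also returns `attn` as the second element of a tuple.
    return attn @ value

q = k = v = torch.randn(2, 8, 16, 64)
print(sdpa_math_reference(q, k, v, is_causal=True).shape)  # torch.Size([2, 8, 16, 64])
```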

aten/src/ATen/native/transformers/attention.h

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ namespace at {
 namespace native {
 
 using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
-      const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal);
+      const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal);
 
 DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
 

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 14 additions & 19 deletions
@@ -557,7 +557,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
 
 #endif
   const auto dim_per_head = D / num_head;
-  if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 ) {
+  if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && !need_weights) {
 
     // We have not done linear projection yet but the input for SDP
     // Is expected to be 4 dimensional. We "cheaply" create view tensors

@@ -566,7 +566,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
     auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
     auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
 
-    sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, need_weights, false};
+    sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, false};
     auto backend = select_sdp_backend(kernel_params);
     if (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention) {
       auto x = at::linear(query, qkv_weight, qkv_bias);

@@ -580,10 +580,9 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
       chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
                       .transpose(1, 2);
 
-      auto y = at::_scaled_dot_product_attention(
-          chunks[0], chunks[1], chunks[2], mask, 0.0, need_weights, false);
-      auto past_sdp =
-          std::get<0>(y).transpose(1, 2).reshape({x_size_0, -1, embed_dim});
+      auto y = at::scaled_dot_product_attention(
+          chunks[0], chunks[1], chunks[2], mask, 0.0, false);
+      auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
       return std::make_tuple(
           at::linear(past_sdp, proj_weight, proj_bias), Tensor());
     }

@@ -680,12 +679,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
 
-std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
+std::tuple<Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
     double dropout_p,
-    bool return_softmax,
     bool is_causal) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention");

@@ -730,8 +728,8 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
   Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
   Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});
 
-  Tensor attention, log_sumexp, softmax;
-  std::tie(attention, log_sumexp, softmax) =
+  Tensor attention, log_sumexp;
+  std::tie(attention, log_sumexp) =
       at::_flash_attention_forward(
           query_reshaped,
           key_reshaped,

@@ -740,14 +738,13 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
           cumulative_sequence_length_k,
           max_seqlen_batch_q,
           max_seqlen_batch_k,
-          return_softmax,
           dropout_p,
           is_causal);
   // Reshape output to convert nnz to batch_size and seq_len
   attention =
       attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
 
-  return std::make_tuple(attention, log_sumexp, softmax);
+  return std::make_tuple(attention, log_sumexp);
 }
 
 std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(

@@ -780,8 +777,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
 }
 
 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
-  sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
+  sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, is_causal};
   auto backend = select_sdp_backend(kernel_params);
   if (backend == sdp::SDPBackend::error) {
     TORCH_CHECK(

@@ -809,15 +806,14 @@ bool _chunk_grad_outputs_efficient_attention(
 }
 
 
-std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
+std::tuple<Tensor, Tensor> _flash_attention_forward(
     const Tensor& query,
     const Tensor& key,
    const Tensor& value,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
-    bool return_softmax,
    double dropout_p,
    bool is_causal) {
 #if defined(USE_FLASH_ATTENTION)

@@ -832,13 +828,12 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
           max_seqlen_batch_k,
           dropout_p,
           softmax_scale,
-          false,
+          false, /*zero_tensors = false for all calls here*/
           is_causal,
-          return_softmax,
           c10::nullopt);
 #endif
   TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
-  return std::make_tuple(Tensor(), Tensor(), Tensor());
+  return std::make_tuple(Tensor(), Tensor());
 }
 
 std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(
