2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -390,7 +390,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(rnn_tanh_cell, lower_precision_fp)
KERNEL(rnn_relu_cell, lower_precision_fp)
KERNEL(_scaled_dot_product_flash_attention, lower_precision_fp)
KERNEL(_scaled_dot_product_attention, lower_precision_fp)
KERNEL(scaled_dot_product_attention, lower_precision_fp)

// fp32
KERNEL(acos, fp32)
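The net effect of this Autocast registration is that the renamed public op follows the same lower-precision policy as the fused attention kernels. A minimal sketch of the user-visible behaviour, assuming a CUDA build of a PyTorch version that already ships `torch.nn.functional.scaled_dot_product_attention` (shapes below are illustrative only):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn(2, 8, 128, 64, device="cuda")
v = torch.randn(2, 8, 128, 64, device="cuda")

# Because scaled_dot_product_attention is registered as lower_precision_fp
# under Autocast, the fp32 inputs are cast down inside the autocast region.
with torch.autocast("cuda", dtype=torch.float16):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.dtype)  # expected: torch.float16
```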
1 change: 1 addition & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -203,6 +203,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(rrelu);
OP_DECOMPOSE(prelu);
OP_DECOMPOSE2(softmax, int);
OP_DECOMPOSE(scaled_dot_product_attention);
OP_DECOMPOSE(special_gammainc);
OP_DECOMPOSE(special_gammaincc);
OP_DECOMPOSE(special_logit);
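Adding `scaled_dot_product_attention` to the functorch decomposition table means `vmap` does not need a dedicated batching rule for it; the op is decomposed into primitives that already have rules. A rough sketch of what this enables (the import path is an assumption that varies by release; older builds expose `vmap` via the separate `functorch` package):

```python
import torch
import torch.nn.functional as F

try:
    from torch.func import vmap   # newer releases
except ImportError:
    from functorch import vmap    # older releases

# Four independent attention problems stacked along dim 0.
q = torch.randn(4, 8, 128, 64)
k = torch.randn(4, 8, 128, 64)
v = torch.randn(4, 8, 128, 64)

# vmap maps over dim 0; the decomposition lets this work without a custom rule.
out = vmap(F.scaled_dot_product_attention)(q, k, v)
print(out.shape)  # torch.Size([4, 8, 128, 64])
```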
14 changes: 10 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -13943,21 +13943,27 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out

# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
[Review comment from the PR author] @cpuhrsch added this back in since your review; it appears some models may have been packaged with this builtin aten op.

python_module: nn
variants: function
autogen: _scaled_dot_product_attention.out

- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
python_module: nn
variants: function
autogen: scaled_dot_product_attention.out

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> int
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
dispatch:
CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda

- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> (Tensor, Tensor)
variants: function

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor)
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False) -> (Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
@@ -13975,7 +13981,7 @@
dispatch:
CUDA: _chunk_grad_outputs_efficient_attention
# Returns ouput, softmax_logsumexp, softmax
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> (Tensor, Tensor)
variants: function
dispatch:
CUDA: _flash_attention_forward
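The schema changes above are the core of the PR: the public `scaled_dot_product_attention` returns a single Tensor and drops `need_attn_weights`/`return_softmax`, while the old private op is kept only so models serialized with it keep loading. A hedged sketch of how the two surface in Python (the legacy call is shown through `torch.ops.aten` purely for illustration):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))

# New public entry point: a single Tensor, no attention-weights return.
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Deprecated builtin kept for already-packaged models; it still returns a
# (output, attn_weights) tuple and simply forwards to the new op when
# need_attn_weights is False.
legacy_out = torch.ops.aten._scaled_dot_product_attention(q, k, v)[0]

print(torch.equal(out, legacy_out))  # expected: True (same math path on CPU)
```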
@@ -321,12 +321,11 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {

} // namespace

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
std::tuple<Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool return_softmax,
bool is_causal) {
TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.")
// Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
@@ -373,13 +372,12 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedten
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
return_softmax,
dropout_p,
is_causal);
// Reshape output to convert nnz to batch_size and seq_len
Tensor attention = std::get<0>(attention_and_lse_and_softmax);
attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2);
return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax));
return std::tie(attention, std::get<1>(attention_and_lse_and_softmax));
}

std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_nestedtensor_cuda(
@@ -496,7 +494,6 @@ Tensor flash_attention_helper(
const Tensor& key,
const Tensor& value,
double dropout_p,
bool need_atten_weights,
bool is_causal) {
// Query is of size (batch_size x ragged_seq_len x (3 or 1) x n_heads x
// head_did
Expand Down Expand Up @@ -541,7 +538,6 @@ Tensor flash_attention_helper(
cumulative_sequence_length_q,
max_seqlen_batch_q,
max_seqlen_batch_q,
false /*return_softmax*/,
dropout_p,
is_causal));
// Output of flash_attention is a regular tensor lets wrap it back up to
65 changes: 31 additions & 34 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -658,10 +658,30 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> native_decoder_only_multi_head_attent
}

int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
return static_cast<int64_t>(sdp::SDPBackend::math);
}

// !!!!!! TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS
// WITH THIS OP BUILTIN !!!!!!
std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
const Tensor& query_,
const Tensor& key,
const Tensor& value,
const c10::optional<Tensor>& attn_mask_,
double dropout_p,
bool need_attn_weights,
bool is_causal) {
if (!need_attn_weights) {
return std::make_tuple(
at::scaled_dot_product_attention(
query_, key, value, attn_mask_, dropout_p, is_causal),
Tensor());
}
return at::_scaled_dot_product_attention_math(
query_, key, value, attn_mask_, dropout_p, is_causal);
}

// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
@@ -690,69 +710,52 @@ int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Ten
// S: Source sequence length
// L: Target sequence length
// E: Embedding dimension
std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
Tensor scaled_dot_product_attention(
const Tensor& query_,
const Tensor& key,
const Tensor& value,
const c10::optional<Tensor>& attn_mask_,
double dropout_p,
bool need_attn_weights,
bool is_causal) {
// TODO: The second return is the attention weights if the math kernel is
// used. The fused kernels do not return this Tensor so for the fused kernels
// The second return SHOULD always be an empty Tensor, unless need_attn_weights
// is true (in which case the fused kernels would not be called). This blows up
// op_info tests.
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (query_.device().type() == DeviceType::CUDA){
choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
query_, key, value, attn_mask_, dropout_p, is_causal);
}
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
switch (backend) {
case sdp::SDPBackend::flash_attention: {
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
query_, key, value, dropout_p, need_attn_weights, is_causal);
return std::make_tuple(
std::move(std::get<0>(out_lse_softmax)),
std::move(std::get<2>(out_lse_softmax)));
query_, key, value, dropout_p, is_causal);
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::efficient_attention: {
bool compute_logsumexp =
(query_.requires_grad() || key.requires_grad() ||
value.requires_grad());
auto out_and_lse = at::_scaled_dot_product_efficient_attention(
query_, key, value, compute_logsumexp, is_causal);
// We need to make an empty tensor in the shape of attention weights
// for the sake of meta tensors.
if (query_.is_nested()) {
// TODO: Need to fix when we have empty for nested tensors.
return out_and_lse;
}
return std::make_tuple(
std::move(std::get<0>(out_and_lse)),
at::empty_symint({0}, query_.options()));
return std::get<0>(out_and_lse);
}
case sdp::SDPBackend::math:
return at::_scaled_dot_product_attention_math(
return std::get<0>(at::_scaled_dot_product_attention_math(
query_,
key,
value,
attn_mask_,
dropout_p,
need_attn_weights,
is_causal);
is_causal));
default:
TORCH_CHECK(
false,
"No viable backend for scaled_dot_product_attention was found.");
return std::make_tuple(Tensor(), Tensor());
return Tensor();
}
}

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal) {
C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
if (query_.is_nested() || key.is_nested() || value.is_nested()) {
TORCH_CHECK(
@@ -797,13 +800,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
if (dropout_p > 0.0) {
attn = at::dropout(attn, dropout_p, true);
}
const auto output = at::matmul(attn, value);
// If you don't need it then you don't get it.
// TODO: Need to fix when we have empty for nested tensors.
attn = need_attn_weights || query_.is_nested()
? attn
: at::empty_symint({0}, query_.options());
return std::make_tuple(output, attn);
return std::make_tuple(at::matmul(attn, value), attn);
}

Tensor triton_multi_head_attention(
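With the dispatcher now returning a single Tensor, `_scaled_dot_product_attention_math` is the composite fallback the fused backends are measured against. A rough reference of the math path in plain PyTorch ops, assuming an additive float mask and the usual 1/sqrt(E) scaling (a sketch, not the exact kernel):

```python
import math
import torch
import torch.nn.functional as F

def sdpa_math_reference(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    # softmax(Q K^T / sqrt(E) [+ mask]) V, with optional dropout on the weights.
    L, S, E = q.size(-2), k.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(E)
    if is_causal:
        keep = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    if attn_mask is not None:
        scores = scores + attn_mask  # assuming an additive float mask
    attn = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn = torch.dropout(attn, dropout_p, True)
    return attn @ v

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
ref = sdpa_math_reference(q, k, v, is_causal=True)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True on the CPU math path
```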
2 changes: 1 addition & 1 deletion aten/src/ATen/native/transformers/attention.h
@@ -8,7 +8,7 @@ namespace at {
namespace native {

using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal);
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal);

DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);

33 changes: 14 additions & 19 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -557,7 +557,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(

#endif
const auto dim_per_head = D / num_head;
if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 ) {
if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 && !need_weights) {

// We have not done linear projection yet but the input for SDP
// Is expected to be 4 dimensional. We "cheaply" create view tensors
@@ -566,7 +566,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);

sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, need_weights, false};
sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, false};
auto backend = select_sdp_backend(kernel_params);
if (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention) {
auto x = at::linear(query, qkv_weight, qkv_bias);
@@ -580,10 +580,9 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
.transpose(1, 2);

auto y = at::_scaled_dot_product_attention(
chunks[0], chunks[1], chunks[2], mask, 0.0, need_weights, false);
auto past_sdp =
std::get<0>(y).transpose(1, 2).reshape({x_size_0, -1, embed_dim});
auto y = at::scaled_dot_product_attention(
chunks[0], chunks[1], chunks[2], mask, 0.0, false);
auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
return std::make_tuple(
at::linear(past_sdp, proj_weight, proj_bias), Tensor());
}
@@ -680,12 +679,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
return std::make_tuple(std::move(proj), std::move(qkt));
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
std::tuple<Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool return_softmax,
bool is_causal) {
// Used for tracking usage statistics
C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention");
@@ -730,8 +728,8 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});

Tensor attention, log_sumexp, softmax;
std::tie(attention, log_sumexp, softmax) =
Tensor attention, log_sumexp;
std::tie(attention, log_sumexp) =
at::_flash_attention_forward(
query_reshaped,
key_reshaped,
@@ -740,14 +738,13 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
return_softmax,
dropout_p,
is_causal);
// Reshape output to convert nnz to batch_size and seq_len
attention =
attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);

return std::make_tuple(attention, log_sumexp, softmax);
return std::make_tuple(attention, log_sumexp);
}

std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
Expand Down Expand Up @@ -780,8 +777,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
}

int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, is_causal};
auto backend = select_sdp_backend(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(
@@ -809,15 +806,14 @@ bool _chunk_grad_outputs_efficient_attention(
}


std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
std::tuple<Tensor, Tensor> _flash_attention_forward(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
bool return_softmax,
double dropout_p,
bool is_causal) {
#if defined(USE_FLASH_ATTENTION)
@@ -832,13 +828,12 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
max_seqlen_batch_k,
dropout_p,
softmax_scale,
false,
false, /*zero_tensors = false for all calls here*/
is_causal,
return_softmax,
c10::nullopt);
#endif
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
return std::make_tuple(Tensor(), Tensor(), Tensor());
return std::make_tuple(Tensor(), Tensor());
}

std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(
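`_fused_sdp_choice_cuda` is what decides, per call, whether the flash, memory-efficient, or math kernel services `scaled_dot_product_attention`. A hedged sketch of steering that choice from Python; the `torch.backends.cuda.sdp_kernel` context manager and its flag names are assumptions that may differ between releases:

```python
import torch
import torch.nn.functional as F

# fp16 inputs with head_dim 64 so the flash kernel is eligible.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict the backend choice so the call either uses flash attention or fails
# loudly, rather than silently falling back to the math path.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```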