76 changes: 66 additions & 10 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -16,6 +16,7 @@
#include <functional>
#include <unordered_set>
#include <vector>
#include <cmath>

namespace sdp {

@@ -29,6 +30,46 @@ struct sdp_params {
bool is_causal;
};

inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
Contributor:

Another way to implement this (and I think it's kind of what this is) is to modify use_flash_attention and use_mem_efficient_attention to return an integer, or to create new functions that return integers.

These integers then are the estimated number of operations performed by the respective fused kernel. This is similar to estimate_matmul_time.

You then pick the one that returns the lowest number of operations. And if the number of operations is negative, well then the kernel doesn't apply.
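
A minimal sketch of that cost-based alternative, assuming the checks were changed (or new helpers added) to return an estimated operation count, with a negative value meaning the fused kernel does not apply. This illustrates the suggestion above, not what the PR implements:

// Pick the backend with the lowest estimated cost; a negative cost means "not applicable".
// flash_ops / mem_efficient_ops would come from hypothetical cost estimators
// (similar in spirit to estimate_matmul_time), which do not exist in this PR.
inline SDPBackend pick_cheapest_backend(int64_t flash_ops, int64_t mem_efficient_ops) {
  const bool flash_ok = flash_ops >= 0;
  const bool mem_ok = mem_efficient_ops >= 0;
  if (flash_ok && (!mem_ok || flash_ops <= mem_efficient_ops)) {
    return SDPBackend::flash_attention;
  }
  if (mem_ok) {
    return SDPBackend::efficient_attention;
  }
  return SDPBackend::math;  // fall back to the composable math implementation
}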

constexpr std::array<SDPBackend, num_backends> default_order{
SDPBackend::flash_attention,
SDPBackend::efficient_attention,
SDPBackend::math};
// Logic is taken from xformers
// FlashAttention parallelizes across "batch_size * num_heads"
// MemEff parallelizes across "batch_size * num_heads * num_queries" and can
// be more efficient (in xformers: batch_size, q_len, num_heads, k = inp.query.shape).
if (params.query.is_nested()) {
// See check_for_nested_inputs for details
return {
SDPBackend::efficient_attention,
SDPBackend::flash_attention,
SDPBackend::math};
}
const auto sizes = params.query.sizes();
if (params.query.dim() != 4) {
return default_order;
}
const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
head_dim{sizes[3]};
if (batch_size > 0) {
const int64_t threads_flash = batch_size * num_heads;
const int64_t threads_cutlass =
threads_flash * (int64_t)std::floor(query_lengths / 64);
bool more_threads_cutlass =
(int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
bool small_threads_flash = threads_flash < 60;
bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
return {
SDPBackend::efficient_attention,
SDPBackend::flash_attention,
SDPBackend::math};
}
}
return default_order;
}
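
To make the branch above concrete, here is a worked example with made-up shapes (batch, num_heads, q_len, head_dim); the numbers are illustrative only:

// Hypothetical query/key shape [2, 8, 512, 64]:
//   threads_flash   = 2 * 8                = 16   (< 60, so small_threads_flash)
//   threads_cutlass = 16 * floor(512 / 64) = 128
//   floor(128 / 2)  = 64 >= 16             (more_threads_cutlass)
//   max head_dim    = 64 != 128            (not large_head_dim)
//   => efficient_attention is ordered first.
// Hypothetical query/key shape [32, 16, 128, 64]:
//   threads_flash = 32 * 16 = 512 (not < 60) => default_order, flash_attention first.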

template <typename dtype_vector>
inline bool check_tensor_dtype(
sdp_params params,
@@ -147,7 +188,7 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
(query_dim == 4 ))) {
if (debug) {
TORCH_WARN(
"Flash attention requires query, key and value to be 4 dimensional, but got Query dim: ",
"Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
query_dim,
", Key dim: ",
params.key.dim(),
@@ -368,23 +409,38 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) {
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
return SDPBackend::error;
}
// Get ideal kernel ordering
const auto ordering = priority_order(kernel_params);

// Because TORCH_CHECK checks if the condition is true, we negate debug so that
// the statements will be printed when debug is true
bool print_debug = false;
if (use_flash_attention(kernel_params, print_debug)) {
return SDPBackend::flash_attention;
}
if (use_mem_efficient_attention(kernel_params, print_debug)) {
return SDPBackend::efficient_attention;
}
if (ctx.userEnabledMathSDP()) {
return SDPBackend::math;
for (auto& backend : ordering) {
switch (backend) {
case SDPBackend::flash_attention:
if (use_flash_attention(kernel_params, print_debug)) {
return SDPBackend::flash_attention;
}
break;
case SDPBackend::efficient_attention:
if (use_mem_efficient_attention(kernel_params, print_debug)) {
return SDPBackend::efficient_attention;
}
break;
case SDPBackend::math:
if (ctx.userEnabledMathSDP()) {
return SDPBackend::math;
}
break;
default:
TORCH_CHECK(false, "Invalid backend");
}
}
// If we have gotten to this point then two things have happened:
// 1. use_flash_attention or use_mem_efficient_attention did not satisfy the
// constraints to be run
// 2. The user has explicitly disabled the math kernel
// We then re-run use_flash_attention with debug enabled to print out the
// We then re-run the kernel checks with debug enabled to print out the
// reason why the kernel was not selected

print_debug = true;
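
The rest of select_sdp_backend is collapsed in this view. A plausible continuation, consistent with the comment above: it simply re-runs both checks with debug enabled and then errors out. The exact warning strings are assumptions, not quoted from the PR:

  TORCH_WARN("Flash attention kernel not used because:");
  use_flash_attention(kernel_params, print_debug);
  TORCH_WARN("Memory efficient kernel not used because:");
  use_mem_efficient_attention(kernel_params, print_debug);
  // print_debug is true here, so this check always fails with the message below.
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.");
  return SDPBackend::error;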
3 changes: 3 additions & 0 deletions aten/src/ATen/native/transformers/sdp_utils_cpp.h
@@ -1,5 +1,8 @@
#pragma once
#include <cstdint>
namespace sdp {

constexpr int32_t num_backends = 3;
Contributor:

This seems like a disturbingly generic name even if it's within the sdp namespace.
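
One hypothetical way to address this (not part of the PR) is to tie the count to the enum itself with a trailing sentinel, assuming the enum continues with flash_attention and efficient_attention:

// Sentinel-based count: the constant's name and value both follow the enum.
enum class SDPBackend { error = -1, math = 0, flash_attention, efficient_attention, num_sdp_backends };
constexpr int32_t kNumSDPBackends = static_cast<int32_t>(SDPBackend::num_sdp_backends);  // == 3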

enum class SDPBackend {
error = -1,
math = 0,