Commit 3a60deb

drisspg authored and pytorchmergebot committed
implement ordering (#91362)
# Summary

In some cases, depending on the input, flash-attention is not the fastest fused kernel and memory-efficient attention is better. This implements a simple heuristic function for deciding the ordering of the kernel functions. It is based on the xformers dispatch function found here: https://github.com/fairinternal/xformers/blob/15bff4986c3a4376176a4e6fa3dc0f2a120fa0bb/xformers/ops/fmha/dispatch.py#L13

Pull Request resolved: #91362
Approved by: https://github.com/cpuhrsch
1 parent 743c385 commit 3a60deb
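To make the tradeoff concrete, here is a minimal standalone sketch of the xformers-style heuristic this commit adds in `priority_order`. The shape values are made up, and the sketch ignores the nested-tensor case and the key tensor's head dimension that the real function also considers; it is an illustration, not the PyTorch implementation. FlashAttention parallelizes across `batch_size * num_heads`, while the memory-efficient (cutlass-based) kernel also splits the query dimension into blocks, so for small `batch_size * num_heads` and longer sequences it can occupy the GPU better.

```cpp
// Standalone sketch of the dispatch heuristic (hypothetical shape values).
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical input shape: [batch_size, num_heads, query_length, head_dim]
  const int64_t batch_size = 1, num_heads = 8, query_length = 512, head_dim = 64;

  // FlashAttention parallelizes across batch_size * num_heads.
  const int64_t threads_flash = batch_size * num_heads;
  // The memory-efficient (cutlass) kernel additionally splits the query
  // dimension into blocks of roughly 64 queries.
  const int64_t threads_cutlass = threads_flash * (query_length / 64);

  const bool small_threads_flash = threads_flash < 60;
  const bool more_threads_cutlass = threads_cutlass / 2 >= threads_flash;
  const bool large_head_dim = head_dim == 128;  // simplified: real code also checks the key

  if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
    std::cout << "try efficient_attention first\n";
  } else {
    std::cout << "try flash_attention first\n";
  }
  return 0;
}
```

With these numbers, `threads_flash` is 8 (below 60) while the cutlass-style count is 64, so the memory-efficient kernel would be tried first; for a shape like batch 32 with 16 heads, `threads_flash` is 512 and the default flash-first order is kept.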

File tree

2 files changed: +69 -10 lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.h

Lines changed: 66 additions & 10 deletions
@@ -16,6 +16,7 @@
 #include <functional>
 #include <unordered_set>
 #include <vector>
+#include <cmath>
 
 namespace sdp {
 
@@ -29,6 +30,46 @@ struct sdp_params {
   bool is_causal;
 };
 
+inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
+  constexpr std::array<SDPBackend, num_backends> default_order{
+      SDPBackend::flash_attention,
+      SDPBackend::efficient_attention,
+      SDPBackend::math};
+  // Logic is taken from xformers
+  // FlashAttention parallelizes across "batch_size * num_heads"
+  // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
+  // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
+  if (params.query.is_nested()) {
+    // See check_for_nested_inputs for details
+    return {
+        SDPBackend::efficient_attention,
+        SDPBackend::flash_attention,
+        SDPBackend::math};
+  }
+  const auto sizes = params.query.sizes();
+  if (params.query.dim() != 4) {
+    return default_order;
+  }
+  const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
+      head_dim{sizes[3]};
+  if (batch_size > 0) {
+    const int64_t threads_flash = batch_size * num_heads;
+    const int64_t threads_cutlass =
+        threads_flash * (int64_t)std::floor(query_lengths / 64);
+    bool more_threads_cutlass =
+        (int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
+    bool small_threads_flash = threads_flash < 60;
+    bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
+    if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
+      return {
+          SDPBackend::efficient_attention,
+          SDPBackend::flash_attention,
+          SDPBackend::math};
+    }
+  }
+  return default_order;
+}
+
 template <typename dtype_vector>
 inline bool check_tensor_dtype(
     sdp_params params,
@@ -147,7 +188,7 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
       (query_dim == 4 ))) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention requires query, key and value to be 4 dimensional, but got Query dim: ",
+          "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
           query_dim,
           ", Key dim: ",
           params.key.dim(),
@@ -368,23 +409,38 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) {
   if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
     return SDPBackend::error;
   }
+  // Get ideal kernel ordering
+  const auto ordering = priority_order(kernel_params);
+
   // Because TORCHCHECK checks if condition is true we negate debug so that
   // The statements will be printed when debug is true
   bool print_debug = false;
-  if (use_flash_attention(kernel_params, print_debug)) {
-    return SDPBackend::flash_attention;
-  }
-  if (use_mem_efficient_attention(kernel_params, print_debug)) {
-    return SDPBackend::efficient_attention;
-  }
-  if (ctx.userEnabledMathSDP()) {
-    return SDPBackend::math;
+  for (auto& backend : ordering) {
+    switch (backend) {
+      case SDPBackend::flash_attention:
+        if (use_flash_attention(kernel_params, print_debug)) {
+          return SDPBackend::flash_attention;
+        }
+        break;
+      case SDPBackend::efficient_attention:
+        if (use_mem_efficient_attention(kernel_params, print_debug)) {
+          return SDPBackend::efficient_attention;
+        }
+        break;
+      case SDPBackend::math:
+        if (ctx.userEnabledMathSDP()) {
+          return SDPBackend::math;
+        }
+        break;
+      default:
+        TORCH_CHECK(false, "Invalid backend");
+    }
   }
   // If we have gotten to this point then two things have happened:
   // 1. use_flash_attention or use_mem_efficient did not satisfy the
   // constraints to be ran
   // 2. The user has explicitly disabled the math kernel
-  // We then re-run use_flash_attention with debug enabled to print out the
+  // We then re-run the kernel checks with debug enabled to print out the
   // reason why the kernel was not selected
 
   print_debug = true;

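The rewritten `select_sdp_backend` keeps the per-kernel eligibility checks; the ordering from `priority_order` only decides which eligible kernel gets preference. Below is a minimal standalone sketch of that ordered-fallback pattern, with stub checks standing in for `use_flash_attention` / `use_mem_efficient_attention` and illustrative enum values; none of this is the actual PyTorch code.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Illustrative mirror of the real enum; the numeric values of the fused
// backends are assumed here for the sketch.
enum class SDPBackend { error = -1, math = 0, flash_attention = 1, efficient_attention = 2 };
constexpr int32_t num_backends = 3;

// Stub eligibility checks; the real ones inspect dtype, shapes, device, etc.
bool flash_eligible() { return false; }          // pretend flash is ruled out
bool mem_efficient_eligible() { return true; }
bool math_enabled() { return true; }

SDPBackend select_backend(const std::array<SDPBackend, num_backends>& ordering) {
  // Walk the priority order and return the first backend whose check passes.
  for (const auto backend : ordering) {
    switch (backend) {
      case SDPBackend::flash_attention:
        if (flash_eligible()) return backend;
        break;
      case SDPBackend::efficient_attention:
        if (mem_efficient_eligible()) return backend;
        break;
      case SDPBackend::math:
        if (math_enabled()) return backend;
        break;
      default:
        return SDPBackend::error;
    }
  }
  // Nothing was eligible: the real code re-runs the checks with debug output.
  return SDPBackend::error;
}

int main() {
  const std::array<SDPBackend, num_backends> ordering{
      SDPBackend::efficient_attention, SDPBackend::flash_attention, SDPBackend::math};
  std::cout << static_cast<int>(select_backend(ordering)) << "\n";  // prints 2 with these stubs
  return 0;
}
```

The design point of the hunk above is the same: reordering changes preference, not correctness, because an ineligible kernel is still skipped and the loop falls through to the next candidate.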
aten/src/ATen/native/transformers/sdp_utils_cpp.h

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 #pragma once
+#include <cstdint>
 namespace sdp {
+
+constexpr int32_t num_backends = 3;
 enum class SDPBackend {
   error = -1,
   math = 0,
