 #include <functional>
 #include <unordered_set>
 #include <vector>
+#include <cmath>
 
 namespace sdp {
 
@@ -29,6 +30,46 @@ struct sdp_params {
   bool is_causal;
 };
 
+inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
+  constexpr std::array<SDPBackend, num_backends> default_order{
+      SDPBackend::flash_attention,
+      SDPBackend::efficient_attention,
+      SDPBackend::math};
+  // Logic is taken from xformers
+  // FlashAttention parallelizes across "batch_size * num_heads"
+  // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
+  // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
+  if (params.query.is_nested()) {
+    // See check_for_nested_inputs for details
+    return {
+        SDPBackend::efficient_attention,
+        SDPBackend::flash_attention,
+        SDPBackend::math};
+  }
+  const auto sizes = params.query.sizes();
+  if (params.query.dim() != 4) {
+    return default_order;
+  }
+  const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
+      head_dim{sizes[3]};
+  if (batch_size > 0) {
+    const int64_t threads_flash = batch_size * num_heads;
+    const int64_t threads_cutlass =
+        threads_flash * (int64_t)std::floor(query_lengths / 64);
+    bool more_threads_cutlass =
+        (int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
+    bool small_threads_flash = threads_flash < 60;
+    bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
+    if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
+      return {
+          SDPBackend::efficient_attention,
+          SDPBackend::flash_attention,
+          SDPBackend::math};
+    }
+  }
+  return default_order;
+}
+
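To see what the heuristic above does in practice, here is a minimal standalone sketch (not part of this commit) that redoes the thread-count arithmetic for a made-up 4-D query shape; the shape values and the key_head_dim variable are assumptions chosen purely for illustration.

// Editorial sketch: works through priority_order's heuristic for a
// hypothetical query shape [batch=2, heads=8, q_len=512, head_dim=64].
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch_size = 2, num_heads = 8, query_lengths = 512, head_dim = 64;
  const int64_t key_head_dim = 64;  // stand-in for params.key.sizes()[3]

  // FlashAttention parallelizes across batch_size * num_heads.
  const int64_t threads_flash = batch_size * num_heads;                      // 16
  // MemEff additionally parallelizes across blocks of 64 queries.
  const int64_t threads_cutlass = threads_flash * (query_lengths / 64);      // 128
  const bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;  // 64 >= 16
  const bool small_threads_flash = threads_flash < 60;                       // true
  const bool large_head_dim = std::max(head_dim, key_head_dim) == 128;       // false

  // Same condition as in priority_order: for this shape the ordering that
  // puts efficient_attention first would be returned.
  if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
    std::cout << "prefer efficient_attention\n";
  } else {
    std::cout << "prefer flash_attention\n";
  }
  return 0;
}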
 template <typename dtype_vector>
 inline bool check_tensor_dtype(
     sdp_params params,
@@ -147,7 +188,7 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
         (query_dim == 4))) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention requires query, key and value to be 4 dimensional, but got Query dim: ",
191+ " Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: " ,
           query_dim,
           ", Key dim: ",
           params.key.dim(),
@@ -368,23 +409,38 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) {
   if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
     return SDPBackend::error;
   }
+  // Get ideal kernel ordering
+  const auto ordering = priority_order(kernel_params);
+
   // Because TORCH_CHECK checks if condition is true we negate debug so that
   // the statements will be printed when debug is true
   bool print_debug = false;
-  if (use_flash_attention(kernel_params, print_debug)) {
-    return SDPBackend::flash_attention;
-  }
-  if (use_mem_efficient_attention(kernel_params, print_debug)) {
-    return SDPBackend::efficient_attention;
-  }
-  if (ctx.userEnabledMathSDP()) {
-    return SDPBackend::math;
+  for (auto& backend : ordering) {
+    switch (backend) {
+      case SDPBackend::flash_attention:
+        if (use_flash_attention(kernel_params, print_debug)) {
+          return SDPBackend::flash_attention;
+        }
+        break;
+      case SDPBackend::efficient_attention:
+        if (use_mem_efficient_attention(kernel_params, print_debug)) {
+          return SDPBackend::efficient_attention;
+        }
+        break;
+      case SDPBackend::math:
+        if (ctx.userEnabledMathSDP()) {
+          return SDPBackend::math;
+        }
+        break;
+      default:
+        TORCH_CHECK(false, "Invalid backend");
+    }
   }
   // If we have gotten to this point then two things have happened:
   // 1. use_flash_attention or use_mem_efficient did not satisfy the
   //    constraints to be run
   // 2. The user has explicitly disabled the math kernel
-  // We then re-run use_flash_attention with debug enabled to print out the
+  // We then re-run the kernel checks with debug enabled to print out the
   // reason why the kernel was not selected
 
   print_debug = true;
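For context, the loop above is a first-match dispatch over the ordering returned by priority_order. The following standalone sketch (not part of this commit) mocks the constraint checks with hypothetical stand-in functions (can_use_flash, can_use_mem_efficient, math_enabled) to show the selection and debug re-run pattern under those assumptions.

// Editorial sketch: first-match backend dispatch with mocked checks.
#include <array>
#include <iostream>

enum class Backend { flash_attention, efficient_attention, math, error };

// Hypothetical stand-ins for use_flash_attention / use_mem_efficient_attention /
// ctx.userEnabledMathSDP(); the real constraint checks live in sdp_utils.
bool can_use_flash(bool debug) {
  if (debug) std::cout << "flash rejected: (reason would be printed here)\n";
  return false;
}
bool can_use_mem_efficient(bool /*debug*/) { return true; }
bool math_enabled() { return true; }

Backend select(const std::array<Backend, 3>& ordering) {
  bool print_debug = false;
  for (auto backend : ordering) {
    switch (backend) {
      case Backend::flash_attention:
        if (can_use_flash(print_debug)) return backend;
        break;
      case Backend::efficient_attention:
        if (can_use_mem_efficient(print_debug)) return backend;
        break;
      case Backend::math:
        if (math_enabled()) return backend;
        break;
      default:
        return Backend::error;
    }
  }
  // Nothing matched: re-run the checks with debug output enabled, as the real
  // function does after this point, then report failure.
  print_debug = true;
  can_use_flash(print_debug);
  can_use_mem_efficient(print_debug);
  return Backend::error;
}

int main() {
  const auto chosen =
      select({Backend::efficient_attention, Backend::flash_attention, Backend::math});
  std::cout << (chosen == Backend::efficient_attention ? "efficient_attention" : "other")
            << "\n";
  return 0;
}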