
Commit 2b2247e

create nested specific aten function
1 parent 758735b commit 2b2247e


12 files changed (+57 −68 lines)


aten/src/ATen/native/cpu/FlashAttentionKernel.cpp

Lines changed: 6 additions & 14 deletions
@@ -55,8 +55,6 @@ void cpu_flash_attention(
     const Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
     const Tensor& philox_seed,
     const Tensor& philox_offset,
     const Tensor& debug_attn_mask,
@@ -279,8 +277,6 @@ void cpu_flash_attention_backward(
     const at::Tensor& logsumexp,
     const Tensor& cumulative_sequence_length_q,
     const Tensor& cumulative_sequence_length_k,
-    const int64_t max_seqlen_batch_q,
-    const int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
     const at::Tensor& philox_seed,
@@ -540,8 +536,6 @@ void flash_attention_kernel_impl(
     const Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
     const Tensor& philox_seed,
     const Tensor& philox_offset,
     const Tensor& debug_attn_mask,
@@ -558,17 +552,17 @@ void flash_attention_kernel_impl(
     if (q_seq_len >= 768) {
       cpu_flash_attention<scalar_t, 256, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<scalar_t, 64, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     } else {
       cpu_flash_attention<scalar_t, 32, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     }
   });
@@ -586,8 +580,6 @@ void flash_attention_backward_kernel_impl(
     const at::Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    const int64_t max_q,
-    const int64_t max_k,
     double dropout_p,
     bool is_causal,
     const at::Tensor& philox_seed,
@@ -604,19 +596,19 @@ void flash_attention_backward_kernel_impl(
       cpu_flash_attention_backward<scalar_t, 256, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention_backward<scalar_t, 64, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     } else {
       cpu_flash_attention_backward<scalar_t, 32, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     }
   });

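The dropped max_q/max_k parameters are simply removed from the CPU kernel signatures and call sites; the query block size is still chosen purely from q_seq_len, which suggests the CPU path never used them. A minimal Python sketch of that selection, for illustration only (the real logic is the template dispatch in flash_attention_kernel_impl above; the helper name is made up):

    def cpu_flash_q_block_size(q_seq_len: int) -> int:
        # Mirrors the q-block template argument chosen above;
        # the kv block size stays at 512 in all three branches.
        if q_seq_len >= 768:
            return 256
        if q_seq_len >= 192:
            return 64
        return 32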
aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 2 deletions
@@ -14349,14 +14349,18 @@
   variants: function
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CPU: _scaled_dot_product_flash_attention_cpu
     CUDA: _scaled_dot_product_flash_attention_cuda
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_flash_attention_nested(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  dispatch:
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   device_check: NoCheck
   variants: function
   dispatch:

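For callers, the visible effect of the schema change is the arity of the returned tuple: the dense overload now yields seven outputs, while the new _nested overload keeps the int max_q/max_k outputs. A minimal sketch of unpacking the dense op (assumes a flash-attention-capable CUDA build; shapes are arbitrary):

    import torch

    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Dense schema: max_q / max_k are no longer part of the return tuple.
    (out, logsumexp, cum_seq_q, cum_seq_k,
     philox_seed, philox_offset, debug_mask) = (
        torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
    )

    # _scaled_dot_product_flash_attention_nested (NestedTensorCUDA only) still
    # returns nine values, with int max_q and max_k in positions 4 and 5.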
aten/src/ATen/native/transformers/attention.cpp

Lines changed: 12 additions & 7 deletions
@@ -561,6 +561,10 @@ at::Tensor post_process_flash_output(
   return out;
 }
 
+bool has_nested_inputs(const Tensor& query, const Tensor& key, const Tensor& value){
+  return query.is_nested() || key.is_nested() || value.is_nested();
+}
+
 } // namespace
 
 // Computes scaled dot product attention on query, key and value tensors, using
@@ -617,6 +621,11 @@ Tensor scaled_dot_product_attention(
   Tensor value_padded = pad_last_dim<8, false>(value);
   // We need to calculate the scale based off the OG head dim size
   auto og_scale = sdp::calculate_scale(query_, scale);
+  if (has_nested_inputs(query_padded, key_padded, value_padded)) {
+    auto out_lse_softmax = at::_scaled_dot_product_flash_attention_nested(
+        query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
+    return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
+  }
   auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
       query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
   return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
@@ -715,8 +724,6 @@ std::tuple<
     at::Tensor,
     at::Tensor,
     at::Tensor,
-    int64_t,
-    int64_t,
     at::Tensor,
     at::Tensor,
     at::Tensor>
@@ -751,21 +758,19 @@ _scaled_dot_product_flash_attention_cpu(
       query.options().dtype(accumulate_dtype));
   at::Tensor cum_seq_q = at::empty({}, at::kLong);
   at::Tensor cum_seq_k = at::empty({}, at::kLong);
-  int64_t max_q = 0;
-  int64_t max_k = 0;
   at::Tensor philox_seed = at::empty({}, at::kLong);
   at::Tensor philox_offset = at::empty({}, at::kLong);
   at::Tensor debug_attn_mask = at::empty({}, query.options());
 
   flash_attention_kernel(kCPU, output, logsumexp, cum_seq_q, cum_seq_k,
-      max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+      philox_seed, philox_offset, debug_attn_mask,
       query, key, value, dropout_p, is_causal, return_debug_mask, scale);
 
   output = output.transpose(1, 2);
   logsumexp = logsumexp.transpose(1, 2);
 
   return std::make_tuple(std::move(output), std::move(logsumexp),
-      std::move(cum_seq_q), std::move(cum_seq_k), max_q, max_k,
+      std::move(cum_seq_q), std::move(cum_seq_k),
       std::move(philox_seed), std::move(philox_offset), std::move(debug_attn_mask));
 }
 
@@ -802,7 +807,7 @@ _scaled_dot_product_flash_attention_backward_cpu(
 
   flash_attention_backward_kernel(kCPU, grad_q, grad_k, grad_v,
       grad_out_t, q_t, k_t, v_t, o_t, lse_t, cum_seq_q, cum_seq_k,
-      max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale);
+      dropout_p, is_causal, philox_seed, philox_offset, scale);
 
   grad_q = grad_q.transpose(1, 2);
   grad_k = grad_k.transpose(1, 2);

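The new routing in scaled_dot_product_attention reads as: if any of the (padded) inputs is nested, call the _nested op, otherwise the dense op, and in both cases only the first output is post-processed and returned. Roughly, in Python terms (a sketch of the C++ control flow, not an actual API; the function name is made up):

    import torch

    def _flash_sdpa_route(query, key, value, dropout_p=0.0, is_causal=False, scale=None):
        # Same test as the has_nested_inputs() helper added above.
        if query.is_nested or key.is_nested or value.is_nested:
            op = torch.ops.aten._scaled_dot_product_flash_attention_nested
        else:
            op = torch.ops.aten._scaled_dot_product_flash_attention
        outputs = op(query, key, value, dropout_p, is_causal, False, scale=scale)
        return outputs[0]  # the remaining outputs feed the backward pass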
aten/src/ATen/native/transformers/attention.h

Lines changed: 1 addition & 2 deletions
@@ -52,7 +52,7 @@ TORCH_API Tensor qkv_projection(
 using flash_attention_fn = void (*)(
     const Tensor& output, const Tensor& logsumexp,
     const Tensor& cum_seq_q, const Tensor& cum_seq_k,
-    int64_t& max_q, int64_t& max_k, const Tensor& philox_seed,
+    const Tensor& philox_seed,
     const Tensor& philox_offset, const Tensor& debug_attn_mask,
     const Tensor& query, const Tensor& key, const Tensor& value,
     double dropout_p, bool is_causal, bool return_debug_mask,
@@ -64,7 +64,6 @@ using flash_attention_backward_fn = void (*)(
     const Tensor& query, const Tensor& key,
     const Tensor& value, const Tensor& out, const Tensor& logsumexp,
     const Tensor& cum_seq_q, const Tensor& cum_seq_k,
-    const int64_t max_q, const int64_t max_k,
     double dropout_p, bool is_causal,
     const Tensor& philox_seed, const Tensor& philox_offset,
     c10::optional<double> scale);

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 7 additions & 7 deletions
@@ -642,7 +642,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   }
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
@@ -691,7 +691,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Ten
   // Reshape output to convert nnz to batch_size and seq_len
   Tensor attention = output.transpose(1,2);
 
-  return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
+  return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), philox_seed, philox_offset, debug_attn_mask);
 }
 
 std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
@@ -828,11 +828,11 @@ _flash_attention_forward(
   debug_attn_mask =
       return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
   return std::make_tuple(
-      output,
-      logsumexp,
-      philox_seed,
-      philox_offset,
-      debug_attn_mask);
+      std::move(output),
+      std::move(logsumexp),
+      std::move(philox_seed),
+      std::move(philox_offset),
+      std::move(debug_attn_mask));
 
 #endif
   TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")

test/test_transformers.py

Lines changed: 2 additions & 0 deletions
@@ -1189,6 +1189,7 @@ def ones_tensor(*shape):
         _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
         torch.cuda.synchronize()
 
+    @slowTest
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -2496,6 +2497,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
             self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                              atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
 
+    @slowTest
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [256, 512, 1024])

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 3 deletions
@@ -2764,9 +2764,9 @@
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
 
-- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
-  output_differentiability: [True, False, False, False, False, False, False, False, False]
-  query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False, False, False]
+  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, query.sym_size(2), key.sym_size(2), dropout_p, is_causal, philox_seed, philox_offset, scale)
 
 # - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
 #   output_differentiability: [True, False, False, False, False, False, False, False]

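Because max_q and max_k are no longer forward outputs, the derivative formula reconstructs them from the symbolic sequence dimensions of the dense [batch, num_heads, seq_len, head_dim] inputs. For regular dense tensors, query.sym_size(2) and key.sym_size(2) are simply the query and key sequence lengths:

    import torch

    query = torch.randn(2, 8, 128, 64)  # [batch, num_heads, seq_len_q, head_dim]
    key = torch.randn(2, 8, 256, 64)    # [batch, num_heads, seq_len_k, head_dim]

    # What the derivative formula's query.sym_size(2) / key.sym_size(2)
    # evaluate to for dense inputs:
    max_q = query.size(2)  # 128
    max_k = key.size(2)    # 256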
torch/_decomp/decompositions.py

Lines changed: 1 addition & 4 deletions
@@ -4093,7 +4093,7 @@ def scaled_dot_product_flash_attention(
     return_debug_mask: bool = False,
     *,
     scale: Optional[float] = None,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     dtype = query.dtype
     batchSize, num_head, qSize, headSize = (
         query.shape[0],
@@ -4123,7 +4123,6 @@
     cum_seq_q, cum_seq_k = torch.empty([], dtype=torch.long), torch.empty(
         [], dtype=torch.long
     )
-    max_q, max_k = 0, 0
     philox_seed, philox_offset = torch.empty([], dtype=torch.long), torch.empty(
         [], dtype=torch.long
     )
@@ -4175,8 +4174,6 @@
         logsumexp,
         cum_seq_q,
         cum_seq_k,
-        max_q,
-        max_k,
         philox_seed,
         philox_offset,
         debug_attn_mask,

torch/_meta_registrations.py

Lines changed: 0 additions & 4 deletions
@@ -4900,8 +4900,6 @@ def meta__scaled_dot_product_flash(
         logsumexp,
         torch.empty((), dtype=torch.int32, device="meta"),
         torch.empty((), dtype=torch.int32, device="meta"),
-        0,
-        0,
         torch.empty((), dtype=torch.long, device="meta"),
         torch.empty((), dtype=torch.long, device="meta"),
         torch.empty((), dtype=query.dtype, device=query.device),
@@ -4941,8 +4939,6 @@
         logsumexp,
         None,
         None,
-        max_seqlen_batch_q,
-        max_seqlen_batch_k,
         torch.empty((), dtype=torch.long, device="meta"),
         torch.empty((), dtype=torch.long, device="meta"),
         debug_mask,

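One way to sanity-check the trimmed meta signature is to call the op on meta tensors, which exercises this registration without running real kernels. A sketch only; flash-attention dtype and shape checks may still apply depending on the build:

    import torch

    q = torch.randn(2, 8, 128, 64, device="meta", dtype=torch.float16)
    k, v = torch.empty_like(q), torch.empty_like(q)

    outs = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
    # output, logsumexp, cum_seq_q, cum_seq_k, philox_seed, philox_offset, debug_attn_mask
    assert len(outs) == 7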
torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 3 additions & 5 deletions
@@ -164,11 +164,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle* ret1, // returns new reference
     AtenTensorHandle* ret2, // returns new reference
     AtenTensorHandle* ret3, // returns new reference
-    int64_t* ret4,
-    int64_t* ret5,
-    AtenTensorHandle* ret6, // returns new reference
-    AtenTensorHandle* ret7, // returns new reference
-    AtenTensorHandle* ret8 // returns new reference
+    AtenTensorHandle* ret4, // returns new reference
+    AtenTensorHandle* ret5, // returns new reference
+    AtenTensorHandle* ret6 // returns new reference
 );
 
 // This function will create a new uninitialized tensor object
