@@ -55,8 +55,6 @@ void cpu_flash_attention(
     const Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
     const Tensor& philox_seed,
     const Tensor& philox_offset,
     const Tensor& debug_attn_mask,
@@ -279,8 +277,6 @@ void cpu_flash_attention_backward(
     const at::Tensor& logsumexp,
     const Tensor& cumulative_sequence_length_q,
     const Tensor& cumulative_sequence_length_k,
-    const int64_t max_seqlen_batch_q,
-    const int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
     const at::Tensor& philox_seed,
@@ -540,8 +536,6 @@ void flash_attention_kernel_impl(
     const Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    int64_t& max_q,
-    int64_t& max_k,
     const Tensor& philox_seed,
     const Tensor& philox_offset,
     const Tensor& debug_attn_mask,
@@ -558,17 +552,17 @@ void flash_attention_kernel_impl(
     if (q_seq_len >= 768) {
       cpu_flash_attention<scalar_t, 256, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<scalar_t, 64, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     } else {
       cpu_flash_attention<scalar_t, 32, 512>(
         output, logsumexp, cum_seq_q, cum_seq_k,
-        max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
+        philox_seed, philox_offset, debug_attn_mask,
         query, key, value, dropout_p, is_causal, return_debug_mask, scale);
     }
   });
@@ -586,8 +580,6 @@ void flash_attention_backward_kernel_impl(
     const at::Tensor& logsumexp,
     const Tensor& cum_seq_q,
     const Tensor& cum_seq_k,
-    const int64_t max_q,
-    const int64_t max_k,
     double dropout_p,
     bool is_causal,
     const at::Tensor& philox_seed,
@@ -604,19 +596,19 @@ void flash_attention_backward_kernel_impl(
       cpu_flash_attention_backward<scalar_t, 256, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention_backward<scalar_t, 64, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     } else {
       cpu_flash_attention_backward<scalar_t, 32, 512>(
         grad_q, grad_k, grad_v, grad_out_contig,
         query, key, value, out, logsumexp,
-        cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
+        cum_seq_q, cum_seq_k, dropout_p,
         is_causal, philox_seed, philox_offset, scale);
     }
   });