@@ -234,18 +234,17 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_nested(
234234 return std::make_tuple (Tensor (), Tensor ());
235235 }
236236}
237- namespace {
237+
238238
239239/* *
240240 * This function is used to calculate two pieces of metadata that are needed
241241 * for use with flash-attention and efficient_attention kernels. They are the
242242 * cumulative sequence_length over a batch of sequences and the maximum sequence
243243 * length.
244244 *
245- * @return A tuple of cumulative sequence lengths and the maximum sequence length,
246- * and the last element in the cumulative_sequence_lengths
245+ * @return A tuple of cumulative sequence lengths and the maximum sequence length
247246 */
248- std::tuple<Tensor, int64_t , int64_t > cumulative_and_max_seq_len (Tensor qkv) {
247+ std::tuple<Tensor, int64_t > cumulative_and_max_seq_len (Tensor qkv) {
249248 TORCH_CHECK (
250249 qkv.is_nested (),
251250 " QKV must be nested for flash cumulative_seq_len calculation." )
@@ -275,7 +274,7 @@ std::tuple<Tensor, int64_t, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
275274 // Send to GPU, this is pretty light weight calc for normal batch size
276275 // but maybe this needs to be on gpu
277276 cumulative_seqlen = cumulative_seqlen.to (TensorOptions ().device (at::kCUDA ));
278- return std::tuple<Tensor, int64_t , int64_t >{cumulative_seqlen, max_seqlen, sum };
277+ return std::tuple<Tensor, int64_t >{cumulative_seqlen, max_seqlen};
279278}
280279
281280/* *
@@ -338,7 +337,6 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
338337 return true ;
339338}
340339
341- } // namespace
342340std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked (
343341 const Tensor& query,
344342 const Tensor& key,
@@ -356,19 +354,19 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
356354 Tensor k_t = key.transpose (1 , 2 );
357355 Tensor v_t = value.transpose (1 , 2 );
358356
359- auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len (q_t );
360- auto cumulative_and_max_k_and_nnz_k = cumulative_and_max_seq_len (k_t );
357+ auto cumulative_and_max_q = cumulative_and_max_seq_len (q_t );
358+ auto cumulative_and_max_k = cumulative_and_max_seq_len (k_t );
361359
362360 // K and V have to have the same Nnz, should probably torch_check
363361 // assume in order to not iterate over v
364362
365- Tensor cumulative_sequence_length_q = std::get<0 >(cumulative_and_max_q_and_nnz_q );
366- Tensor cumulative_sequence_length_k = std::get<0 >(cumulative_and_max_k_and_nnz_k );
363+ Tensor cumulative_sequence_length_q = std::get<0 >(cumulative_and_max_q );
364+ Tensor cumulative_sequence_length_k = std::get<0 >(cumulative_and_max_k );
367365
368- const int64_t max_seqlen_batch_q = std::get<1 >(cumulative_and_max_q_and_nnz_q );
366+ const int64_t max_seqlen_batch_q = std::get<1 >(cumulative_and_max_q );
369367
370- const int64_t Nnz_q = std::get< 2 >(cumulative_and_max_q_and_nnz_q );
371- const int64_t Nnz_kv = std::get< 2 >(cumulative_and_max_k_and_nnz_k );
368+ const int64_t Nnz_q = cumulative_sequence_length_q[- 1 ]. item < int64_t >( );
369+ const int64_t Nnz_kv = cumulative_sequence_length_k[- 1 ]. item < int64_t >( );
372370
373371 Tensor query_buffer_reshaped;
374372 Tensor key_buffer_reshaped;
@@ -462,15 +460,15 @@ Tensor flash_attention_helper(
462460 int64_t head_dim{query.size (-1 )};
463461 int64_t num_heads{query.size (-2 )};
464462
465- auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len (query);
466- Tensor cumulative_sequence_length_q = std::get<0 >(cumulative_and_max_q_and_nnz_q );
467- int64_t max_seqlen_batch_q = std::get<1 >(cumulative_and_max_q_and_nnz_q );
463+ auto cumulative_and_max_q = cumulative_and_max_seq_len (query);
464+ Tensor cumulative_sequence_length_q = std::get<0 >(cumulative_and_max_q );
465+ int64_t max_seqlen_batch_q = std::get<1 >(cumulative_and_max_q );
468466
469467 TORCH_CHECK (
470468 key.is_same (key) && query.is_same (value),
471469 " Key and Value must be the same tensor" );
472470
473- int64_t Nnz_q = std::get< 2 >(cumulative_and_max_q_and_nnz_q) ;
471+ int64_t Nnz_q{cumulative_sequence_length_q[- 1 ]. item < int64_t >()} ;
474472
475473 // For the packed case we need to set the output size for dim 2 to 1
476474 auto atten_size = get_nested_size_tensor (query).clone ();
0 commit comments