@@ -86,7 +86,10 @@ class RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };
 
@@ -143,10 +146,14 @@ class Runner : public RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
+
         T* k_ptr = nullptr;
         T* v_ptr = nullptr;
         AttentionOutT* context_buf = static_cast<AttentionOutT*>(output.slice(0, token_offset).data_ptr());
@@ -209,6 +216,22 @@ class Runner : public RunnerBase
             mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
             mla_params.q_pe_ld = q_pe->strides()[1];
             mla_params.q_pe_stride = q_pe->strides()[0];
+
+            mla_params.seqQOffset
+                = cu_q_seqlens.has_value() ? reinterpret_cast<int*>(cu_q_seqlens.value().data_ptr()) : nullptr;
+            mla_params.cu_kv_seqlens
+                = cu_kv_seqlens.has_value() ? reinterpret_cast<int*>(cu_kv_seqlens.value().data_ptr()) : nullptr;
+            mla_params.fmha_tile_counter = fmha_scheduler_counter.has_value()
+                ? reinterpret_cast<uint32_t*>(fmha_scheduler_counter.value().data_ptr())
+                : nullptr;
+            mla_params.bmm1_scale = mla_bmm1_scale.has_value()
+                ? reinterpret_cast<float*>(mla_bmm1_scale.value().data_ptr())
+                : nullptr;
+            mla_params.bmm2_scale = mla_bmm2_scale.has_value()
+                ? reinterpret_cast<float*>(mla_bmm2_scale.value().data_ptr())
+                : nullptr;
+            mla_params.quant_q_buf
+                = quant_q_buffer.has_value() ? reinterpret_cast<void*>(quant_q_buffer.value().data_ptr()) : nullptr;
         }
         mla_params.q_buf = attention_input;
         mla_params.context_buf = reinterpret_cast<T*>(context_buf);
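The six new assignments above all repeat the same null-check-and-cast step to turn an optional tensor into a raw device pointer. A minimal sketch of a helper that could express that pattern (hypothetical, not part of this change; `dataPtrOrNull` and the usage lines are illustrative only):

```cpp
// Hypothetical helper (not in the diff): convert an optional torch::Tensor into
// a typed raw pointer, or nullptr when the tensor was not supplied.
#include <optional>
#include <torch/torch.h>

template <typename T>
T* dataPtrOrNull(std::optional<torch::Tensor> const& t)
{
    return t.has_value() ? reinterpret_cast<T*>(t->data_ptr()) : nullptr;
}

// Possible usage, equivalent in meaning to the assignments above:
//   mla_params.seqQOffset = dataPtrOrNull<int>(cu_q_seqlens);
//   mla_params.fmha_tile_counter = dataPtrOrNull<uint32_t>(fmha_scheduler_counter);
//   mla_params.bmm1_scale = dataPtrOrNull<float>(mla_bmm1_scale);
```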
@@ -541,7 +564,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
     std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
-    std::vector<std::optional<torch::Tensor>> sparse_attention_params)
+    std::vector<std::optional<torch::Tensor>> sparse_attention_params, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer)
 {
     // Decompress sparse attention parameters
     TORCH_CHECK(sparse_attention_params.size() == 4, "Expected 4 sparse attention parameters");
@@ -569,6 +595,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
     }
 
+    // 2. Data type detection and Runner creation
     auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
     bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
     bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
@@ -624,6 +651,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     int64_t const rotary_embedding_max_positions = rotary_embedding_max_position_info[0];
     int64_t const rotary_embedding_original_max_positions = rotary_embedding_max_position_info[1];
 
+    // 3. AttentionOp creation and initialization
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
@@ -709,6 +737,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             = chunked_prefill_buffer_batch_size.has_value() ? chunked_prefill_buffer_batch_size.value() : 1;
     }
 
+    // 4. Cache check and initialization
     auto cache_key = std::make_tuple(op->data(), runner->data());
     using CacheKey = decltype(cache_key);
     static std::unordered_map<CacheKey, std::shared_ptr<AttentionOp>, hash<CacheKey>> op_cache;
@@ -726,6 +755,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         op_cache[cache_key] = op;
     }
 
+    // 5. Request type and input type detection: ctx, gen: for continuous batching
     int32_t const num_seqs = host_context_lengths.size(0);
     RequestType const* request_types = static_cast<RequestType const*>(host_request_types.data_ptr());
 
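For context, the hunk above touches a function-local static cache (`op_cache`) that memoizes initialized AttentionOp instances keyed by the op and runner configuration. A self-contained sketch of that lookup-or-create pattern, with invented names and a simplified key and hash (the real key comes from `op->data()` and `runner->data()` and uses the file's own `hash<CacheKey>`):

```cpp
// Illustrative sketch of a static lookup-or-create cache keyed on a tuple,
// similar in spirit to op_cache above. All names and types here are invented.
#include <cstddef>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>

struct FakeOp
{
    bool initialized = false;
};

std::shared_ptr<FakeOp> getOrCreateOp(std::string const& opData, std::string const& runnerData)
{
    using CacheKey = std::tuple<std::string, std::string>;
    struct KeyHash
    {
        std::size_t operator()(CacheKey const& k) const
        {
            return std::hash<std::string>{}(std::get<0>(k)) ^ (std::hash<std::string>{}(std::get<1>(k)) << 1);
        }
    };
    // Function-local static: the cache persists across calls, so expensive
    // initialization happens once per unique (op, runner) configuration.
    static std::unordered_map<CacheKey, std::shared_ptr<FakeOp>, KeyHash> cache;

    auto key = std::make_tuple(opData, runnerData);
    auto it = cache.find(key);
    if (it != cache.end())
    {
        return it->second;
    }
    auto op = std::make_shared<FakeOp>();
    op->initialized = true; // stands in for the real one-time initialization
    cache[key] = op;
    return op;
}
```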
@@ -758,6 +788,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         TLLM_CHECK(request_types[idx] == RequestType::kGENERATION);
     }
 
+    // 6. Workspace management
     int32_t const max_attention_window_size
         = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
     int32_t const max_blocks_per_sequence
@@ -805,7 +836,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
-            attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
+            attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
+            cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -822,7 +854,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
-            attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
+            attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
+            cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
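The casts added in the Runner imply the element types of the new side buffers: the cumulative sequence-length tensors are read as `int`, the FMHA scheduler counter as `uint32_t`, the two MLA BMM scales as `float`, and the quantized-Q buffer as an opaque `void*`. A hedged sketch of allocating buffers with matching dtypes (every shape and size below is a placeholder invented for illustration; none of them is specified in this diff):

```cpp
// Sketch only: dtypes follow the reinterpret_casts in the diff; shapes are placeholders.
#include <torch/torch.h>

int main()
{
    auto const opts = torch::TensorOptions().device(torch::kCUDA);

    int64_t const numSeqs = 4;        // placeholder batch size
    int64_t const quantQBytes = 1024; // placeholder byte count for the quantized Q buffer

    auto cu_q_seqlens = torch::zeros({numSeqs + 1}, opts.dtype(torch::kInt32));   // read as int*
    auto cu_kv_seqlens = torch::zeros({numSeqs + 1}, opts.dtype(torch::kInt32));  // read as int*
    auto fmha_scheduler_counter = torch::zeros({1}, opts.dtype(torch::kInt32));   // reinterpreted as uint32_t*
    auto mla_bmm1_scale = torch::ones({1}, opts.dtype(torch::kFloat32));          // read as float*
    auto mla_bmm2_scale = torch::ones({1}, opts.dtype(torch::kFloat32));          // read as float*
    auto quant_q_buffer = torch::empty({quantQBytes}, opts.dtype(torch::kUInt8)); // seen only as void*
    return 0;
}
```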