
Commit 5e4f5de

accuracy on unittest with 1 layer, bf16 + fp8
Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>

accuracy right on unittest
Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>

ut and quickstart advanced pass, with full comment
Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>

drop print info
Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>
1 parent ec32711 commit 5e4f5de

File tree

11 files changed: +1061 −78 lines changed


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 5 additions & 33 deletions
@@ -977,33 +977,8 @@ int AttentionOp::mlaGeneration(
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;
-
-    size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
-    size_t const fmha_scheduler_counter = sizeof(uint32_t);
-    size_t const mla_bmm1_scale_size = mFP8GenerationMLA ? sizeof(float) * 2 : 0;
-    size_t const mla_bmm2_scale_size = mFP8GenerationMLA ? sizeof(float) : 0;
-    size_t const quant_q_buffer_size = mFP8GenerationMLA
-        ? params.acc_q_len * size_t(mNumHeads * (mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim))
-        : 0;
-    int* cu_q_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
-    int* cu_kv_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
-    uint32_t* fmha_tile_counter_ptr
-        = reinterpret_cast<uint32_t*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter));
-    float* mla_bmm1_scale_ptr
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm1_scale_size));
-    float* mla_bmm2_scale_ptr
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm2_scale_size));
-    void* quant_q_buffer_ptr
-        = reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, quant_q_buffer_size));
     void* scratch_ptr = nextWorkspacePtr(workspace_byte_ptr, offset);
 
-    params.seqQOffset = cu_q_seqlens;
-    params.cu_kv_seqlens = cu_kv_seqlens;
-    params.fmha_tile_counter = fmha_tile_counter_ptr;
-    params.bmm1_scale = mla_bmm1_scale_ptr;
-    params.bmm2_scale = mla_bmm2_scale_ptr;
-    params.quant_q_buf = quant_q_buffer_ptr;
-
     params.quant_scale_o = generation_params.attention_output_orig_quant;
     params.quant_scale_q = generation_params.kv_scale_orig_quant;
     params.quant_scale_kv = generation_params.kv_scale_orig_quant;
@@ -1012,9 +987,6 @@ int AttentionOp::mlaGeneration(
     params.host_bmm1_scale
         = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
 
-    invokeMLARopeGeneration<T>(params, kv_cache_buffer, stream);
-    sync_check_cuda_error(stream);
-
     if (generation_params.runtime_perf_knobs)
     {
         int64_t multi_block_mode_val = generation_params.runtime_perf_knobs[0];
@@ -1245,7 +1217,7 @@ int AttentionOp::mlaGeneration(
         XQAParams xqaParams{};
         this->template convertMMHAParamsToXQAParams<T, decltype(kv_cache_buffer)>(
             xqaParams, generation_params, /*forConfigurePlugin=*/false);
-        xqaParams.quant_q_buffer_ptr = quant_q_buffer_ptr;
+        xqaParams.quant_q_buffer_ptr = params.quant_q_buf;
         xqaParams.q_scaling
             = 1 / (mQScaling * sqrtf((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
         if (mEnableXQA && mXqaDispatcher->shouldUse(xqaParams))
@@ -1287,11 +1259,11 @@ int AttentionOp::mlaGeneration(
 
         // fmhaParams.packedMaskPtr = params.fmha_custom_mask;
         fmhaParams.pagedKvCache = kv_cache_buffer;
-        fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
+        fmhaParams.cuQSeqLenPtr = params.seqQOffset;
         fmhaParams.kvSeqLenPtr = params.cache_seq_lens;
-        fmhaParams.cuKvSeqLenPtr = cu_kv_seqlens;
+        fmhaParams.cuKvSeqLenPtr = params.cu_kv_seqlens;
         fmhaParams.cuMaskRowsPtr = nullptr; // mla not support custorm mask right now
-        fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+        fmhaParams.tileCounterPtr = params.fmha_tile_counter;
         fmhaParams.scaleBmm1Ptr = reinterpret_cast<float const*>(params.bmm1_scale);
         fmhaParams.scaleBmm2Ptr = reinterpret_cast<float const*>(params.bmm2_scale);
         fmhaParams.stream = stream;
@@ -1608,7 +1580,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // 1. only apply to self attention. If want fused multi-head cross attention, FMHCA kernels and runner is needed
     // 2. doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
     // We update mEnableContextFMHA in constructor to check these conditions
-    if (mEnableContextFMHA)
+    if (mEnableContextFMHA) // fused
     {
         // do all-to-all for params.attention_input, need to split on kv head
         // [token_num // cp_size, kv_heads, head_size] -> [token_num, kv_heads // cp_size, head_size]
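The deleted block above carved all six MLA buffers (cu_q_seqlens, cu_kv_seqlens, the FMHA tile counter, both BMM scale buffers, and the quantized-Q buffer) out of the op's single workspace allocation with nextWorkspacePtr; after this change the caller supplies them as tensors and mlaGeneration only forwards params.seqQOffset, params.cu_kv_seqlens, params.fmha_tile_counter, the BMM scales, and params.quant_q_buf to the XQA/FMHA paths. Below is a minimal sketch of that pointer-bump pattern, assuming a hypothetical nextWorkspacePtrSketch helper and a 256-byte alignment; both are illustrative and not the repository's implementation.

    // Illustrative only: carve sub-buffers out of one workspace allocation by
    // advancing a byte offset, as the removed lines did with nextWorkspacePtr.
    #include <cstddef>
    #include <cstdint>

    static int8_t* nextWorkspacePtrSketch(int8_t* base, size_t& offset, size_t size, size_t alignment = 256)
    {
        int8_t* ptr = base + offset;
        // Round the running offset up so the next sub-buffer stays aligned.
        offset += (size + alignment - 1) / alignment * alignment;
        return ptr;
    }

    // Usage mirroring the removed code (workspace and batch_size are placeholders):
    //   size_t offset = 0;
    //   size_t const cu_seqlens_size = sizeof(int) * (batch_size + 1);
    //   int* cu_q_seqlens  = reinterpret_cast<int*>(nextWorkspacePtrSketch(workspace, offset, cu_seqlens_size));
    //   int* cu_kv_seqlens = reinterpret_cast<int*>(nextWorkspacePtrSketch(workspace, offset, cu_seqlens_size));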

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 99 additions & 0 deletions
@@ -932,6 +932,33 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c
         params.cache_type, params.quant_scale_kv);
 }
 
+__global__ void printCudaVectorInt32(int32_t const* vec, int32_t size)
+{
+    for (int i = 0; i < size; i++)
+    {
+        printf("%d, ", vec[i]);
+    }
+    printf("\n");
+}
+
+__global__ void printCudaVectorUint32(uint32_t const* vec, int32_t size)
+{
+    for (int i = 0; i < size; i++)
+    {
+        printf("%u, ", vec[i]);
+    }
+    printf("\n");
+}
+
+__global__ void printCudaVectorFloat(float const* vec, int32_t size)
+{
+    for (int i = 0; i < size; i++)
+    {
+        printf("%f, ", vec[i]);
+    }
+    printf("\n");
+}
+
 template <typename T>
 void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream)
 {
@@ -989,12 +1016,84 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
     attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
     config.numAttrs = 1;
     config.attrs = attrs;
+    // printf("=================invokeMLARopeGeneration============\n");
+    // printf("head_num: %zu\n", params.head_num);
+    // printf("kv_lora_rank: %d\n", params.meta.kv_lora_rank);
+    // printf("acc_q_len: %d\n", params.acc_q_len);
+    // printf("seq_len: %d\n", seq_len);
+    // printf("q_pe_ld: %d\n", params.q_pe_ld);
+    // printf("q_pe_stride: %d\n", params.q_pe_stride);
+    // printf("cache_type: %d\n", static_cast<int>(params.cache_type));
+    // printf("host_bmm1_scale: %f\n", params.host_bmm1_scale);
+    // // Print some CUDA vector variables for debugging
+    // printf("cache_seq_lens: ");
+    // printCudaVectorInt32<<<1, 1, 0, stream>>>(params.cache_seq_lens, params.batch_size);
+    // cudaDeviceSynchronize();
+
+    // if (params.quant_scale_o)
+    // {
+    //     printf("quant_scale_o: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.quant_scale_o, 1);
+    //     cudaDeviceSynchronize();
+    // }
+
+    // if (params.quant_scale_q)
+    // {
+    //     printf("quant_scale_q: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.quant_scale_q, 1);
+    //     cudaDeviceSynchronize();
+    // }
+    // if (params.quant_scale_kv)
+    // {
+    //     printf("quant_scale_kv: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.quant_scale_kv, 1);
+    //     cudaDeviceSynchronize();
+    // }
+
+    // if (params.bmm1_scale)
+    // {
+    //     printf("bmm1_scale: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.bmm1_scale, 2);
+    //     cudaDeviceSynchronize();
+    // }
+    // if (params.bmm2_scale)
+    // {
+    //     printf("bmm2_scale: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.bmm2_scale, 1);
+    //     cudaDeviceSynchronize();
+    // }
+
     cudaLaunchKernelEx(&config, kernel_instance, params.q_buf, params.q_pe, params.latent_cache, params.quant_q_buf,
         kv_cache_buffer, params.cos_sin_cache, params.head_num, params.meta.kv_lora_rank, params.acc_q_len, seq_len,
         params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
         params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
         params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
         params.host_bmm1_scale, params.helix_position_offsets);
+
+    // cudaDeviceSynchronize();
+    // printf("Output\n");
+    // printf("seqQOffset: ");
+    // printCudaVectorInt32<<<1, 1, 0, stream>>>(params.seqQOffset, params.batch_size + 1);
+    // cudaDeviceSynchronize();
+    // printf("seqKVOffsets: ");
+    // printCudaVectorInt32<<<1, 1, 0, stream>>>(params.cu_kv_seqlens, params.batch_size + 1);
+    // cudaDeviceSynchronize();
+    // printf("fmha_tile_counter: ");
+    // printCudaVectorUint32<<<1, 1, 0, stream>>>(params.fmha_tile_counter, 1);
+    // cudaDeviceSynchronize();
+    // if (params.bmm1_scale)
+    // {
+    //     printf("bmm1_scale: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.bmm1_scale, 2);
+    //     cudaDeviceSynchronize();
+    // }
+    // if (params.bmm2_scale)
+    // {
+    //     printf("bmm2_scale: ");
+    //     printCudaVectorFloat<<<1, 1, 0, stream>>>(params.bmm2_scale, 1);
+    //     cudaDeviceSynchronize();
+    // }
+    // printf("====================\n");
 }
 
 template <typename T, typename TCache>
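The print kernels added above are plain single-thread debug helpers: each one is launched as <<<1, 1, 0, stream>>> on the same stream as the RoPE kernel, and the surrounding (commented-out) cudaDeviceSynchronize() calls are what flush the device-side printf buffer back to the host. A self-contained sketch of that usage follows; dumpDeviceFloats and its arguments are illustrative and not part of the repository.

    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>

    __global__ void printCudaVectorFloat(float const* vec, int32_t size)
    {
        // A single thread walks the array serially so the output stays ordered.
        for (int i = 0; i < size; i++)
        {
            printf("%f, ", vec[i]);
        }
        printf("\n");
    }

    void dumpDeviceFloats(float const* devPtr, int32_t n, cudaStream_t stream)
    {
        printCudaVectorFloat<<<1, 1, 0, stream>>>(devPtr, n);
        // Device-side printf output is only guaranteed to appear after a sync.
        cudaDeviceSynchronize();
    }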

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,9 @@ void initBindings(nb::module_& m)
         nb::arg("mla_tensor_params"), nb::arg("attention_chunk_size") = std::nullopt,
         nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
         nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_attention_params"), "Multi-head attention operation",
+        nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
+        nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
+        nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
         nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::thop
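Both binding changes (nanobind here, pybind11 below) follow the same pattern: the new arguments are declared with a std::nullopt default so existing Python call sites that do not pass them keep working. A minimal standalone nanobind sketch of that pattern, assuming the std::optional caster from nanobind/stl/optional.h handles the std::nullopt default as in the binding above; the module, function, and plain std::optional<int> stand-ins are illustrative, not repository code.

    #include <nanobind/nanobind.h>
    #include <nanobind/stl/optional.h>
    #include <optional>

    namespace nb = nanobind;

    // Counts how many of the optional keyword arguments were actually supplied.
    static int count_provided(std::optional<int> cu_q_seqlens, std::optional<int> quant_q_buffer)
    {
        return int(cu_q_seqlens.has_value()) + int(quant_q_buffer.has_value());
    }

    NB_MODULE(optional_args_demo, m)
    {
        m.def("count_provided", &count_provided,
            nb::arg("cu_q_seqlens") = std::nullopt,
            nb::arg("quant_q_buffer") = std::nullopt,
            "Demo of trailing optional keyword arguments");
    }

Calling count_provided(cu_q_seqlens=3) from Python would return 1, and omitting both arguments returns 0, which is the backward-compatible behaviour the defaults are there to preserve.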

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 4 additions & 1 deletion
@@ -56,6 +56,9 @@ void initBindings(pybind11::module_& m)
         py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
         py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
         py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params"), "Multi-head attention operation",
-        py::call_guard<py::gil_scoped_release>());
+        py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
+        py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
+        py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
+        "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::thop

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@ add_library(
   mtpOp.cpp
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
-  tinygemm2.cpp)
+  tinygemm2.cpp
+  dsv3RopeOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 38 additions & 5 deletions
@@ -86,7 +86,10 @@ class RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };
 
@@ -143,10 +146,14 @@ class Runner : public RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
+
         T* k_ptr = nullptr;
         T* v_ptr = nullptr;
         AttentionOutT* context_buf = static_cast<AttentionOutT*>(output.slice(0, token_offset).data_ptr());
@@ -209,6 +216,22 @@ class Runner : public RunnerBase
             mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
             mla_params.q_pe_ld = q_pe->strides()[1];
             mla_params.q_pe_stride = q_pe->strides()[0];
+
+            mla_params.seqQOffset
+                = cu_q_seqlens.has_value() ? reinterpret_cast<int*>(cu_q_seqlens.value().data_ptr()) : nullptr;
+            mla_params.cu_kv_seqlens
+                = cu_kv_seqlens.has_value() ? reinterpret_cast<int*>(cu_kv_seqlens.value().data_ptr()) : nullptr;
+            mla_params.fmha_tile_counter = fmha_scheduler_counter.has_value()
+                ? reinterpret_cast<uint32_t*>(fmha_scheduler_counter.value().data_ptr())
+                : nullptr;
+            mla_params.bmm1_scale = mla_bmm1_scale.has_value()
+                ? reinterpret_cast<float*>(mla_bmm1_scale.value().data_ptr())
+                : nullptr;
+            mla_params.bmm2_scale = mla_bmm2_scale.has_value()
+                ? reinterpret_cast<float*>(mla_bmm2_scale.value().data_ptr())
+                : nullptr;
+            mla_params.quant_q_buf
+                = quant_q_buffer.has_value() ? reinterpret_cast<void*>(quant_q_buffer.value().data_ptr()) : nullptr;
         }
         mla_params.q_buf = attention_input;
         mla_params.context_buf = reinterpret_cast<T*>(context_buf);
@@ -541,7 +564,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
     std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
-    std::vector<std::optional<torch::Tensor>> sparse_attention_params)
+    std::vector<std::optional<torch::Tensor>> sparse_attention_params, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer)
 {
     // Decompress sparse attention parameters
     TORCH_CHECK(sparse_attention_params.size() == 4, "Expected 4 sparse attention parameters");
@@ -569,6 +595,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
     }
 
+    // 2. Data type detection and Runner creation
     auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
     bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
     bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
@@ -624,6 +651,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     int64_t const rotary_embedding_max_positions = rotary_embedding_max_position_info[0];
     int64_t const rotary_embedding_original_max_positions = rotary_embedding_max_position_info[1];
 
+    // 3. AttentionOp creation and initialization
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
@@ -709,6 +737,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         = chunked_prefill_buffer_batch_size.has_value() ? chunked_prefill_buffer_batch_size.value() : 1;
     }
 
+    // 4. Cache check and initialization
     auto cache_key = std::make_tuple(op->data(), runner->data());
     using CacheKey = decltype(cache_key);
     static std::unordered_map<CacheKey, std::shared_ptr<AttentionOp>, hash<CacheKey>> op_cache;
@@ -726,6 +755,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         op_cache[cache_key] = op;
     }
 
+    // 5. Request type and input type detection: ctx, gen: for continuous batching
     int32_t const num_seqs = host_context_lengths.size(0);
     RequestType const* request_types = static_cast<RequestType const*>(host_request_types.data_ptr());
 
@@ -758,6 +788,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         TLLM_CHECK(request_types[idx] == RequestType::kGENERATION);
     }
 
+    // 6. Workspace management
     int32_t const max_attention_window_size
         = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
     int32_t const max_blocks_per_sequence
@@ -805,7 +836,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
         rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
         mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
-        attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
+        attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
+        cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -822,7 +854,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
         rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
         mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
-        attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
+        attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
+        cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
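Inside the Runner, each of the new optional tensors is reduced to a raw pointer, with nullptr standing in when the Python side did not provide the buffer (for example, the non-FP8 path has no quant_q_buffer). A condensed sketch of that mapping; the optionalTensorPtr helper is illustrative and not part of the repository.

    #include <cstdint>
    #include <optional>
    #include <torch/torch.h>

    // Turn an optional torch tensor into a raw pointer, or nullptr when absent.
    template <typename T>
    T* optionalTensorPtr(std::optional<torch::Tensor> const& t)
    {
        return t.has_value() ? reinterpret_cast<T*>(t->data_ptr()) : nullptr;
    }

    // e.g. mla_params.seqQOffset        = optionalTensorPtr<int>(cu_q_seqlens);
    //      mla_params.fmha_tile_counter = optionalTensorPtr<uint32_t>(fmha_scheduler_counter);
    //      mla_params.quant_q_buf       = optionalTensorPtr<void>(quant_q_buffer);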

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 4 additions & 1 deletion
@@ -61,6 +61,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
     std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
-    std::vector<std::optional<torch::Tensor>> sparse_attention_params);
+    std::vector<std::optional<torch::Tensor>> sparse_attention_params, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer);
 
 } // namespace torch_ext
