
Conversation

@zhhuang-nv
Collaborator

@zhhuang-nv zhhuang-nv commented Aug 1, 2025

Description

Use a separate QKV input layout for context MLA, covering the default path, the KV cache reuse path, and the chunked context path.

This eliminates several extra operations: concat in the default path, set_paged_kv in the KV cache reuse path, and set_chunked_kv in the chunked context path. It also removes the copy of the V tensor by setting the correct stride for it.
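As a rough illustration of the stride point, here is a minimal PyTorch sketch; the dimensions and the packed [k_nope | v] layout are illustrative assumptions rather than the exact buffers used by the kernels:

    import torch

    # Hypothetical per-head dimensions, for illustration only.
    num_tokens, num_heads = 8, 128
    qk_nope_head_dim, v_head_dim = 128, 128

    # A buffer that packs [k_nope | v] contiguously per head, similar in spirit to the context MLA layout.
    kv = torch.randn(num_tokens, num_heads, qk_nope_head_dim + v_head_dim)

    # V as a strided view into the packed buffer: no copy, just a non-contiguous slice
    # that the attention kernel can consume by reading the correct strides.
    v = kv[..., qk_nope_head_dim:]
    assert v.data_ptr() == kv.data_ptr() + qk_nope_head_dim * kv.element_size()
    assert not v.is_contiguous()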

This PR also adds fmha_v2 kernels with separate QKV support for SM120. After this PR, the support matrix for DeepSeek V3/R1 is:

  • Default: SM90/SM100/SM120
  • KV Cache Reuse: SM90/SM100/SM120 (enabled by default)
  • Chunked Context: SM90/SM100 (disabled by default)

Benchmark on 8×B200 with nvidia/DeepSeek-R1-0528-FP4, FP8 KV cache enabled, ISL=1K, OSL=2K, num-requests=114688 (following this doc):

  • Total Token Throughput (tokens/sec) before this PR: 74324.6462
  • Total Token Throughput (tokens/sec) after this PR: 75218.5976 (1.2% improvement)

The speedup is small because this PR only focuses on the context phase, while the generation phase dominates this test case and max_num_tokens is only 2048.

P.S. After changing OSL to 1 and num-requests to 49152, we measured a 3.4% throughput improvement (76440.4384 vs. 79022.3162).

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
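For example, a composed invocation using the options documented above (illustrative only, reusing the example stage name):

    /bot run --disable-fail-fast --stage-list "A10-PyTorch-1"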

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can cause the top of tree to break.

Summary by CodeRabbit

  • New Features

    • Separate Q/K/V input support, new SEPARATE_Q_K_V layout, and explicit total-KV-length tracking (host_total_kv_lens).
  • Refactor

    • Rewrote MLA/attention flows to use per-tensor Q/K/V buffers and unified context/generation wiring; kernel selection and tiling updated for separate Q/K/V.
  • Bug Fixes

    • Improved handling of non‑contiguous V layouts, strides, and shape correctness.
  • Chores / Removals

    • Removed legacy paged/chunked MLA KV-cache APIs, kernels, tests, and corresponding Torch/Python bindings.
  • Performance / FP8

    • New FP8 quantization path with dedicated per‑tensor FP8 buffers for Q/K/V.

@zhhuang-nv zhhuang-nv requested a review from a team as a code owner August 1, 2025 04:42
@zhhuang-nv zhhuang-nv requested review from hyukn and nv-yilinf August 1, 2025 04:42
@coderabbitai
Contributor

coderabbitai bot commented Aug 1, 2025

📝 Walkthrough

Walkthrough

Refactors MLA/context attention to use separate Q, K, V buffers (including per-tensor FP8 quant buffers), removes MLA paged/chunked KV-cache setter kernels and helpers, introduces AttentionInputLayout::SEPARATE_Q_K_V across FMHA runners/kernels, updates C++/Python bindings and metadata to carry total_kv_len / host_total_kv_lens, and adapts tests and kernel selection/tiles accordingly.

Changes

Cohort / File(s) Change Summary
Attention core (C++ / ThOP)
cpp/tensorrt_llm/common/attentionOp.cpp, cpp/tensorrt_llm/common/attentionOp.h, cpp/tensorrt_llm/thop/attentionOp.cpp, cpp/tensorrt_llm/thop/attentionOp.h
Replace fused QKV buffer with separate q_buf/k_buf/v_buf and per-tensor FP8 quant buffers; allocate fp8_q/k/v workspace when enabled; add total_kv_len and optional k_ptr/v_ptr; update enqueueContext/getWorkspaceSizeForContext and signatures/serialization.
FMHA runner & params
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp, cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h, cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
Add AttentionInputLayout::SEPARATE_Q_K_V; add pointer members k_ptr/v_ptr and headSizeQkNope/mHeadDimQkNope; adjust setupLaunch/setTma/stride logic and softmax gating to support separate/non-contiguous V and pointer-based Q/K/V.
MLA kernels & params header
cpp/tensorrt_llm/kernels/mlaKernels.cu, cpp/tensorrt_llm/kernels/mlaKernels.h
Remove paged/chunked KV setter kernels/traits; refactor RoPE Q/K/V kernel to accept separate q_ptr/k_ptr; add templated FP8 quantize kernel and host API invokeMLAContextFp8Quantize; restructure MlaParams<T> to separate q/k/v and quant buffers; update instantiations.
Chunked/paged prefill removal
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu, cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
Delete setChunkedKV host wrapper, kernel and header decls; consolidate chunked prefill via unified load path; adjust instantiations and dropped BF16 instantiation in one macro.
FMHA v2 tiles, traits & kernel selection
cpp/kernels/fmha_v2/setup.py, cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h, cpp/kernels/fmha_v2/src/fmha/kernel_traits.h, cpp/kernels/fmha_v2/src/fmha/fused_multihead_attention.cpp, cpp/kernels/fmha_v2/fmha_test.py
Introduce SEPARATE_Q_K_V layout, create Gmem_tile_q_k_v tile and Kernel_traits_v2_q_k_v alias; update kernel naming/encoding, selection/gating, and tests to include separate-q-k-v MLA paths and adjust -save-softmax gating.
Dispatcher / generation FMHA glue
cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp, .../fmhaRunnerParams.h, .../kernelParams.h
Map SEPARATE_Q_K_V → SeparateQkv; propagate kPtr/vPtr and new head-dim into generation runner params; pass dtypeKv to stride builders to support non-contiguous V.
Unit tests removed/updated
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu, cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
Remove SetChunkedKV/SetPagedKV test kernels, buffers and test cases; keep remaining MLA tests aligned to new paths.
Python backend & metadata
tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/metadata.py, tensorrt_llm/_torch/pyexecutor/*
Add host_total_kv_lens to wrapper/metadata and plan/run flows; remove set_paged/set_chunked MLA setters; load_paged/chunked now return (kv, k_pe) tuples; rename enable_paged_context_mla → enable_context_mla_with_cached_kv; change KVCacheParams.num_extra_kv_tokens type to Optional[int].
Python MLA module & tests
tensorrt_llm/_torch/modules/attention.py, tests/unittest/_torch/test_attention_mla.py
Remove fused-QKV concat helper and fused path; reconstruct K from NOPE+RoPE per-chunk and pass separate K/V into attention; update tests and metadata usage (new enable_flash_mla in test metadata).
Bindings (pybind / nanobind / thop)
cpp/tensorrt_llm/nanobind/thop/bindings.cpp, cpp/tensorrt_llm/pybind/thop/bindings.cpp, cpp/tensorrt_llm/thop/attentionOp.cpp, cpp/tensorrt_llm/thop/attentionOp.h
Add Python/C++ binding arg host_total_kv_lens; remove mla_context_paged_kv and mla_context_kv_cache_block_offsets; update attention binding signatures to accept separate k/v and total_kv_len.
FMHA cubin metadata & binaries
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/*
Replace/rename many 192x128 qkv/paged_kv entries to q_k_v softmax variants; add q_k_v cubin entries; numerous Git LFS pointer updates/removals (binary-asset metadata-only changes).
Misc / small edits
tensorrt_llm/_torch/models/modeling_deepseekv3.py, cpp/tensorrt_llm/common/attentionOp.h
Minor formatting; add total_kv_len to EnqueueParams and k_ptr/v_ptr to EnqueueContextParams; add headSizeQkNope and AttentionInputLayout::SEPARATE_Q_K_V public members.

Sequence Diagram(s)

sequenceDiagram
    participant User as Python caller
    participant PyLayer as Py MLA/attention layer
    participant Backend as TrtllmAttentionWrapper
    participant CppOp as AttentionOp (C++)
    participant FMHA as FMHA runner / kernels

    User->>PyLayer: forward(q, k?, v?, ...)
    PyLayer->>Backend: plan/run(qkv_or_q, k?, v?, host_total_kv_lens, ...)
    Backend->>CppOp: run(..., qkv_or_q, k?, v?, total_kv_len, ...)
    CppOp->>FMHA: enqueueContext(q_ptr, k_ptr, v_ptr, quant_q/k/v buffers?, headSizeQkNope, layout=SEPARATE_Q_K_V)
    FMHA->>FMHA: prepare kernel params (strides, dtypeKv, head dims) and launch kernels
    FMHA-->>CppOp: results (output, optional softmax stats)
    CppOp-->>Backend: return outputs
    Backend-->>PyLayer: return outputs
    PyLayer-->>User: final output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

SW Architecture

Suggested reviewers

  • hlu1
  • QiJune
  • hyukn
  • Shixiaowei02
  • yechank-nvidia

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@zhhuang-nv zhhuang-nv changed the title from [feat] Use Separate QKV Input Layout for Context MLA to [None][feat] Use Separate QKV Input Layout for Context MLA on Aug 1, 2025
@zhhuang-nv
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)

1-3: Update copyright year to 2025

The copyright header should include the current year (2025) according to coding guidelines.

 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.  All rights reserved.
  *
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

681-681: Remove commented debug print statement.

This appears to be leftover debugging code.

-        # print("prepare TrtllmAttentionMetadata")
         extra_attrs = get_model_extra_attrs()
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ad5742b and fdfba9e.

📒 Files selected for processing (20)
  • cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (4 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (13 hunks)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4 hunks)
  • tensorrt_llm/_torch/metadata.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
  • tensorrt_llm/_torch/modules/attention.py (6 hunks)
💤 Files with no reviewable changes (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/metadata.py
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The block should be prepended to the top of all files, including .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/metadata.py
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/common/attentionOp.h
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • cpp/tensorrt_llm/kernels/mlaKernels.h
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camelcase prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-scope magic-number/literal constants are uppercase snakecas...

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • cpp/tensorrt_llm/kernels/mlaKernels.h
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{h,hpp}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.

Files:

  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • cpp/tensorrt_llm/kernels/mlaKernels.h
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T04:09:12.904Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T04:09:12.904Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The block should be prepended to the top of all files, including .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
🧬 Code Graph Analysis (2)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)
cpp/tensorrt_llm/kernels/multiHeadAttentionCommon.h (2)
  • get_size_in_bytes (81-95)
  • get_size_in_bytes (99-102)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (4)
  • void (75-104)
  • void (107-136)
  • void (141-224)
  • void (229-295)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (48)
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1)

239-240: LGTM: Clear addition of MLA-specific parameter

The addition of mHeadDimQkNope with appropriate documentation clearly indicates its purpose for MLA context handling. The parameter naming follows the established camelCase convention and the comment provides sufficient context.

tensorrt_llm/_torch/metadata.py (1)

31-31: LGTM: Simplified type annotation aligns with refactoring

The change from Optional[List[int]] to Optional[int] simplifies the interface and is consistent with the broader refactoring to handle separate Q, K, V tensors. The field name and default value remain appropriate.

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1)

361-365: LGTM: Correct stride calculation for non-contiguous V tensor

The new conditional branch properly handles the stride calculation for the V tensor in context MLA when using separate QKV layout. The calculation options.mNumHeadsKv * (options.mHeadDimQkNope + options.mHeadDimV) correctly accounts for the non-contiguous layout by including both dimensions. The comment clearly explains the context-specific requirement.
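As a quick numeric sanity check of that formula (sample DeepSeek-like values, not read from the code):

    # Per-token V stride in elements when V is a slice of the packed [k_nope | v] buffer.
    num_heads_kv = 128
    head_dim_qk_nope = 128
    head_dim_v = 128
    v_token_stride = num_heads_kv * (head_dim_qk_nope + head_dim_v)
    assert v_token_stride == 32768  # larger than the 128 * 128 = 16384 elements of actual V data per token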

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (4)

39-42: LGTM: Proper mapping for new SEPARATE_Q_K_V layout

The new case correctly maps AttentionInputLayout::SEPARATE_Q_K_V to QkvLayout::SeparateQkv, maintaining consistency with the established pattern for layout conversions.


155-158: LGTM: Consistent layout handling in dispatcher

The additional case for SEPARATE_Q_K_V layout is properly handled, setting the qkvLayout to SeparateQkv as expected for the new attention input layout.


172-173: LGTM: Explicit assignment of separate K and V pointers

The assignment of kPtr and vPtr from runnerParams enables proper handling of separate K and V tensors, replacing the previous nullptr assignments and supporting the new separate QKV input layout.


191-191: LGTM: Proper propagation of MLA-specific parameter

The assignment of mHeadDimQkNope from mFixedParams.headSizeQkNope correctly propagates the MLA-specific head dimension parameter to the runner parameters.

cpp/tensorrt_llm/common/attentionOp.h (3)

96-96: LGTM: Added necessary KV length tracking

The addition of total_kv_len to EnqueueParams provides essential tracking for total key-value length, supporting the new separate QKV input layout functionality.


130-132: LGTM: Well-documented separate K and V pointer support

The addition of optional k_ptr and v_ptr members is properly documented, clearly indicating their use for separate QKV input specifically for context MLA. The optional nature and specific use case are well communicated.


184-185: LGTM: Complete debug output coverage

The addition of the new K and V pointers to the debug string output ensures comprehensive logging and troubleshooting capabilities for the new separate QKV functionality.

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2)

181-202: LGTM! Well-structured separate QKV layout implementation.

The implementation correctly handles the new SEPARATE_Q_K_V layout by setting separate pointers and calculating appropriate strides. The conditional logic for V tensor stride based on headSizeQkNope properly handles both contiguous and non-contiguous V tensor layouts, which aligns with the MLA context requirements.


633-638: LGTM! Consistent TMA descriptor setup for separate QKV.

The implementation correctly sets the K and V pointers from kernel parameters for the SEPARATE_Q_K_V layout, maintaining consistency with the established pattern used for other layout types.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

896-896: LGTM! Consistent addition of total_kv_lens parameter.

The total_kv_lens parameter is correctly added to the function signature, passed through to the underlying attention_inplace operation, and included in the fake registration. This supports the transition to explicit KV length tracking.

Also applies to: 980-980, 1010-1010


1060-1060: LGTM! Parameter ordering alignment.

The reordering of softmax_stats_tensor in the fake registration ensures consistency with the function signature and underlying operation interface.

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1)

353-353: LGTM! Cleanup aligns with PR objectives.

The removal of MLA chunked KV cache setting kernels aligns with the PR's goal of eliminating these specialized operations in favor of separate QKV input handling. The remaining functionality for merging attention and loading chunked KV is preserved as expected.

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4)

84-87: LGTM: Well-documented enum extension for separate QKV layout.

The addition of SEPARATE_Q_K_V to the AttentionInputLayout enum is well-implemented with clear documentation describing the layout structure and its specific use case for context MLA.


121-122: LGTM: Appropriate addition of MLA-specific parameter.

The headSizeQkNope member is correctly added with proper documentation indicating its MLA-specific usage for the Q/K non-RoPE part dimension.


174-174: LGTM: Consistent string representation for new enum value.

The string conversion method correctly handles the new SEPARATE_Q_K_V layout with appropriate naming that matches the enum value.


264-267: LGTM: Logical addition of separate K and V pointers.

The new kPtr and vPtr members are appropriately added to support the separate QKV input layout, complementing the existing qPtr and kvPtr pointers with clear documentation.

cpp/tensorrt_llm/kernels/mlaKernels.h (1)

55-57: LGTM: Clean refactoring to separate QKV buffers.

The modification of MlaParams to support separate Q, K, V buffers is well-implemented:

  • Clear documentation of the new buffer layouts and dimensions
  • Logical separation of concerns with dedicated buffers for context MLA
  • Consistent naming conventions following the existing pattern
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)

119-119: LGTM: Appropriate test cleanup for removed functionality.

The modification reflects the removal of test infrastructure for the deleted paged KV cache setting kernel. The remaining test coverage for the loadPagedKV functionality is still intact and appropriate.

cpp/tensorrt_llm/common/attentionOp.cpp (4)

1629-1649: Clear documentation for the new separate QKV input layout

The extended comment properly documents the new separate QKV input layout for context MLA, including all required pointers and parameters. This improves code maintainability.


1663-1676: Correct implementation of separate QKV buffer pointers for MLA

The conditional assignment of FMHA device buffer pointers properly handles both MLA (separate Q, K, V) and non-MLA (packed QKV or Q-only) cases, maintaining backward compatibility.


1687-1688: Good simplification of paged KV cache assignment

The direct assignment eliminates unnecessary conditional logic, improving code clarity and aligning with the PR's goal of removing extra operations.


2537-2547: Proper initialization of MLA-specific FMHA parameters

The initialization correctly sets the attention input layout to SEPARATE_Q_K_V and includes the new headSizeQkNope parameter when MLA is enabled, completing the support for separate QKV input layout.

cpp/tensorrt_llm/kernels/mlaKernels.cu (4)

187-189: LGTM! Function signature properly updated for separate Q/K buffers

The change from a single qkv_output to separate q_ptr and k_ptr aligns with the PR objective to eliminate concatenation operations in the default path.


242-246: Correct update to use separate query buffer

The change properly loads query data from the dedicated q_ptr buffer instead of fuse_buf, maintaining the same offset calculations.


826-828: Kernel invocation correctly updated for separate K buffer

The change from params.latent_cache to params.k_buf as the second parameter correctly matches the updated kernel signature that expects separate Q and K pointers.


276-282: Ignore identical Q/K indexing concern

The matching dst_q_idx and dst_k_idx formulas are intentional: both Q and K undergo the same RoPE offset calculation, and because q_ptr and k_ptr point to separate output buffers, the identical index correctly places each tensor’s data in its own memory. No change is needed here.

Likely an incorrect or invalid review comment.

tensorrt_llm/_torch/modules/attention.py (6)

944-949: Clean implementation of explicit K tensor construction

The code correctly constructs the full key tensor by:

  1. Reshaping full_k_nope to separate head dimensions
  2. Broadcasting full_k_pe across all attention heads
  3. Concatenating along the feature dimension
  4. Flattening to the expected shape

This explicit approach successfully replaces the previous fused QKV handling.
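A minimal PyTorch sketch of those four steps; shapes and variable names are illustrative assumptions, not copied from the module:

    import torch

    num_tokens, num_heads = 16, 128
    qk_nope_head_dim, qk_rope_head_dim = 128, 64

    full_k_nope = torch.randn(num_tokens, num_heads * qk_nope_head_dim)
    full_k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)  # single RoPE slice shared by all heads

    k_nope = full_k_nope.view(num_tokens, num_heads, qk_nope_head_dim)  # 1. reshape to expose heads
    k_pe = full_k_pe.expand(num_tokens, num_heads, qk_rope_head_dim)    # 2. broadcast across heads
    full_k = torch.cat([k_nope, k_pe], dim=-1)                          # 3. concatenate feature dims
    full_k = full_k.reshape(num_tokens, -1)                             # 4. flatten to the expected shape
    assert full_k.shape == (num_tokens, num_heads * (qk_nope_head_dim + qk_rope_head_dim))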


960-969: Correct adaptation to separate K/V tensor interface

The mha.forward call properly passes separate full_k and full_v tensors, and the explicit None for latent_cache with explanatory comment helps distinguish this cached KV code path.


1015-1015: Proper handling of total_kv_lens during chunked attention

The code correctly:

  1. Saves the original total_kv_lens[0] value
  2. Temporarily updates it for each chunk and the final attention computation
  3. Restores the original value after processing

This pattern ensures the attention kernels receive the correct KV lengths for each processing stage.

Also applies to: 1060-1061, 1105-1108, 1131-1131
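A minimal sketch of that save/update/restore flow; the attribute and helper names here are placeholders, not the exact ones in attention.py:

    import torch

    def attend_with_kv_len(host_total_kv_lens: torch.Tensor, kv_len: int) -> None:
        # Stand-in for the per-chunk / final attention call; only the length bookkeeping is shown.
        host_total_kv_lens[0] = kv_len

    host_total_kv_lens = torch.tensor([4096, 0], dtype=torch.int64)  # [context total, generation total]
    original_kv_len = int(host_total_kv_lens[0])

    for chunk_kv_len in (1024, 2048, 3072):                   # temporarily set the KV length per chunk
        attend_with_kv_len(host_total_kv_lens, chunk_kv_len)
    attend_with_kv_len(host_total_kv_lens, original_kv_len)   # final attention over the full context
    host_total_kv_lens[0] = original_kv_len                   # restore the original value afterwards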


1033-1054: Consistent K tensor construction with proper memory management

The code follows the established pattern for constructing K tensors and includes proper cleanup of intermediate variables to free memory. The explicit tensor reshaping and concatenation correctly implements the separate Q/K/V approach.


1089-1100: Consistent implementation for uncached KV handling

The K tensor construction for uncached KV follows the same proven pattern, ensuring consistency across all code paths in the chunked prefill implementation.


1112-1115: Final attention call correctly uses separate K/V tensors

The last mha.forward call in the chunked prefill path properly passes separate k and v tensors, maintaining consistency with the new separate QKV architecture.

tensorrt_llm/_torch/attention_backend/trtllm.py (8)

26-26: LGTM! Addition of total_kv_lens tensor.

The addition of total_kv_lens tensor to track total key-value lengths is well-integrated with the existing tensor attributes.


157-157: Well-documented addition of total_kv_lens parameter.

The total_kv_lens parameter is properly added to the plan() method with clear documentation explaining its purpose and shape (2) on CPU.

Also applies to: 201-201, 228-228


328-340: Correct implementation of separate QKV inputs for MLA context.

The changes properly enforce separate Q, K, V inputs for MLA context mode while maintaining fused QKV for generation mode. The hidden size calculation is correctly updated to use only Q dimensions.


412-412: LGTM! Correctly passing total_kv_lens to attention operation.


603-603: Proper initialization and computation of total KV lengths.

The total_kv_lens tensor is correctly initialized with shape (2) on CPU and properly computed by summing KV lengths for context and generation requests separately.

Also applies to: 728-731


1130-1130: LGTM! Propagating total_kv_lens through forward method.


1225-1225: Good use of modern Python type annotation.

Using lowercase tuple is the preferred syntax for Python 3.9+.
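For instance, a stub in that style (illustrative only, not the actual signature in the file):

    import torch

    def load_chunked_kv_cache_for_mla() -> tuple[torch.Tensor, torch.Tensor]:
        # The real method returns a (kv, k_pe) pair; empty tensors stand in here.
        return torch.empty(0), torch.empty(0)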


1267-1279: Good edge case handling for empty cached tokens.

The addition of early return with properly shaped empty tensors when there are no cached tokens is an efficient optimization that avoids unnecessary kernel calls.

cpp/tensorrt_llm/thop/attentionOp.cpp (5)

66-71: Well-documented interface update for separate QKV inputs.

The comment clearly explains the typical single QKV input case and the exception for context MLA using separate inputs. The addition of optional k and v tensors with total_kv_len parameter properly extends the interface.


137-186: Correct implementation of separate K, V buffer handling for MLA context.

The code properly:

  1. Renames the variable to qkv_or_q for clarity
  2. Validates K, V tensors are provided for MLA context
  3. Checks tensor dimensions and strides
  4. Correctly extracts and assigns buffer pointers

290-290: Proper propagation of new parameters through enqueue flow.

The total_kv_len and separate K, V pointers are correctly added to the enqueue parameters structure.

Also applies to: 307-308


415-416: Comprehensive update to support separate QKV inputs.

The changes correctly:

  1. Add total_kv_lens parameter to track KV lengths
  2. Ensure backward compatibility by requiring fused QKV for non-MLA
  3. Extract and use separate total KV lengths for context and generation

Also applies to: 443-445, 617-618


653-653: Consistent updates for renamed variables and new parameters.

The workspace tensor creation correctly uses the renamed qkv_or_q variable, and the TORCH_LIBRARY definition properly includes the new total_kv_lens parameter.

Also applies to: 754-754

@tensorrt-cicd
Collaborator

PR_Github #13764 [ run ] triggered by Bot

@zhhuang-nv
Collaborator Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #13796 [ kill ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13764 [ run ] completed with state ABORTED

@tensorrt-cicd
Collaborator

PR_Github #13796 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit fdfba9e

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)

1-15: Update copyright year to 2025.

According to the coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header that includes the current year.

 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.  All rights reserved.
  *
🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)

681-681: Remove commented-out debug print statement.

-        # print("prepare TrtllmAttentionMetadata")

201-201: Line exceeds maximum length of 120 characters.

-            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and 
+                                         generation requests, with shape (2) on CPU.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fdfba9e and d47541e.

📒 Files selected for processing (20)
  • cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (4 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (13 hunks)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4 hunks)
  • tensorrt_llm/_torch/metadata.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
  • tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🚧 Files skipped from review as they are similar to previous changes (12)
  • tensorrt_llm/_torch/metadata.py
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/mlaKernels.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (2)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (13)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1)

1-360: LGTM! Removal of MLA chunked KV cache setting kernels aligns with PR objectives.

The removal of setChunkedKVKernelTraits, setChunkedKVCacheForMLAKernel, and invokeMLASetChunkedKV successfully eliminates the MLA-specific chunked KV cache setting functionality as intended. The retained functions for merging attention and loading chunked KV cache appear to serve different purposes and are correctly preserved.

cpp/tensorrt_llm/kernels/mlaKernels.cu (2)

186-282: LGTM! Successful refactoring to support separate Q, K, V buffers.

The modification of applyMLARopeAndAssignQKVKernelOptContext to accept separate q_ptr and k_ptr parameters correctly implements the new SEPARATE_Q_K_V layout. The kernel properly writes Q data to the dedicated Q buffer and K data to the K buffer, eliminating the need for the removed paged KV cache setting operations.


825-829: LGTM! Clean removal of MLA paged KV cache setting functionality.

The removal of setPagedKVKernelTraits, setPagedKVCacheForMLAKernel, and invokeMLASetPagedKV successfully eliminates the MLA-specific paged KV cache setting functionality as intended. The kernel invocation correctly uses the new separate buffer interface.

tensorrt_llm/_torch/attention_backend/trtllm.py (3)

26-26: LGTM! Well-integrated addition of total_kv_lens tensor.

The new total_kv_lens tensor is properly integrated throughout the attention wrapper and metadata classes. It correctly tracks the total key-value lengths for context and generation requests separately, which aligns with the new separate Q, K, V buffer design.

Also applies to: 157-157, 228-228, 412-412, 728-731, 1130-1130


328-332: LGTM! Correct enforcement of separate Q, K, V inputs for MLA.

The updated assertions properly enforce that MLA context phase requires separate Q, K, V inputs (not is_fused_qkv) while generation phase still uses fused QKV. This aligns with the removal of fused QKV support in MLA context operations.


1273-1279: LGTM! Proper handling of empty tensor case in load_chunked_kv_cache_for_mla.

The function correctly returns empty tensors with appropriate shapes when there are no cached tokens, preventing potential runtime errors.

cpp/tensorrt_llm/thop/attentionOp.cpp (7)

66-83: Well-designed interface update for separate QKV support.

The function signature changes clearly support the new separate QKV input layout for MLA context processing. The comment effectively explains the dual usage pattern, and the parameter naming (qkv_or_q, k, v, total_kv_len) is descriptive and intuitive.


120-135: Consistent implementation of the updated interface.

The Runner::run method signature correctly matches the updated RunnerBase interface, ensuring proper virtual function override behavior.


137-210: Robust implementation of separate QKV tensor handling.

The implementation correctly handles both fused and separate QKV tensor inputs with appropriate validation:

  • Proper dimensional and stride checks for K/V tensors in MLA context mode
  • Consistent pointer initialization and assignment
  • Clear branching logic between context and generation modes
  • Memory layout assumptions are properly validated

The validation checks ensure runtime safety while maintaining performance-critical assumptions about tensor layout.


290-290: Correct parameter propagation to enqueue operations.

The new total_kv_len, k_ptr, and v_ptr parameters are properly assigned to the enqueue parameter structures, ensuring the separate tensor information flows correctly through the attention pipeline.

Also applies to: 307-308


412-436: Well-structured function signature update.

The addition of the total_kv_lens parameter is logically positioned with other length-related parameters, maintaining good parameter organization and readability.


443-445: Comprehensive implementation with robust validation.

The implementation correctly handles the transition to separate QKV support:

  • Clear validation ensures either MLA is enabled or fused QKV is used, preventing invalid configurations
  • Consistent parameter renaming (qkv_or_q) throughout the function
  • Proper extraction and passing of total_kv_lens values for both context and generation stages
  • Informative error messages guide users toward correct usage

The dual-path support (context and generation) maintains the same parameter passing pattern, ensuring consistency.

Also applies to: 614-685


754-754: Consistent Torch library binding update.

The addition of the total_kv_lens parameter to the Torch library definition correctly reflects the updated C++ function signature, maintaining proper Python-to-C++ interface consistency.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

201-201: Line exceeds maximum length limit.

Line 201 is 148 characters long, exceeding the 120 character limit.

Consider breaking this line for better readability:

-            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and 
+                                         generation requests, with shape (2) on CPU.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d47541e and e8bc4b3.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/tensorrt_llm/kernels/mlaKernels.cu (2)

2-2: Copyright year correctly updated.


187-282: Kernel correctly updated for separate Q/K/V layout.

The changes properly implement the separation of QKV tensors:

  • Function signature updated to accept separate q_ptr and k_ptr parameters
  • Memory access patterns correctly adjusted for reading from and writing to separate buffers
  • Kernel invocation updated to pass params.k_buf instead of params.latent_cache

Also applies to: 826-826

tensorrt_llm/_torch/attention_backend/trtllm.py (3)

26-26: Total KV length tracking correctly implemented.

The addition of total_kv_lens tensor properly tracks the total key-value lengths for context and generation requests separately, which aligns with the PR objectives.

Also applies to: 157-157, 228-228, 412-412, 603-603, 727-730, 1129-1129
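
As a rough illustration of what such a shape-(2) CPU tensor could hold, a minimal sketch is shown below; the variable names are assumptions for illustration, not the actual trtllm.py implementation.

import torch

def compute_total_kv_lens(kv_lens_without_extra: torch.Tensor, num_contexts: int) -> torch.Tensor:
    # Per-request KV lengths with extra tokens already excluded; the first
    # num_contexts entries belong to context requests, the rest to generation.
    total_kv_lens = torch.zeros(2, dtype=torch.int64, device='cpu')
    total_kv_lens[0] = kv_lens_without_extra[:num_contexts].sum()
    total_kv_lens[1] = kv_lens_without_extra[num_contexts:].sum()
    return total_kv_lens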


328-333: MLA context mode correctly enforces separate Q/K/V inputs.

The assertions properly ensure that MLA context mode uses separate inputs instead of fused QKV, aligning with the PR's goal of eliminating extra operations.
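
A minimal sketch of the kind of guard being described follows; the flag name and helper function are illustrative assumptions rather than the exact trtllm.py code.

from enum import Enum

class AttentionInputType(Enum):  # illustrative stand-in for the real enum
    context_only = 0
    generation_only = 1

def check_mla_input_layout(input_type: AttentionInputType, is_fused_qkv: bool) -> None:
    if input_type is AttentionInputType.context_only:
        # Context MLA now takes separate Q/K/V tensors.
        assert not is_fused_qkv, "context MLA expects separate Q/K/V inputs"
    elif input_type is AttentionInputType.generation_only:
        # Generation keeps the fused QKV layout.
        assert is_fused_qkv, "generation MLA expects a fused QKV input"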


1224-1278: Load functions correctly updated to return separate tensors.

The changes to load_paged_kv_cache_for_mla and load_chunked_kv_cache_for_mla properly return tuples of (compressed_kv, k_pe) instead of single tensors, supporting the new separate Q/K/V layout. The empty tensor handling for the no-cache case is also correctly implemented.
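
For orientation, a hedged sketch of the return shape being described; the dimension names, dtypes, and zero-length handling here are assumptions, not the real implementation.

import torch

def load_kv_for_mla_sketch(cached_token_len: int, kv_lora_rank: int, rope_dim: int,
                           dtype: torch.dtype = torch.bfloat16,
                           device: str = 'cpu') -> tuple:
    # When cached_token_len == 0 (the no-cache case) this simply yields 0-row
    # tensors with the expected trailing dimensions, so callers can unpack
    # (compressed_kv, k_pe) unconditionally.
    compressed_kv = torch.empty(cached_token_len, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.empty(cached_token_len, rope_dim, dtype=dtype, device=device)
    return compressed_kv, k_pe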

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

201-201: Consider breaking long documentation line for better readability.

The documentation line exceeds the 120-character limit (148 characters). Consider reformatting:

-            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests 
+                and generation requests, with shape (2) on CPU.
cpp/tensorrt_llm/thop/attentionOp.cpp (1)

443-444: Clarify the validation logic for better readability.

The validation condition is_mla_enable || is_fused_qkv with the error message "Only fused QKV is supported for non-MLA attention now" could be confusing. Consider making the intent clearer:

-TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
+TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Either MLA must be enabled or fused QKV must be used");

Or restructure the logic to be more explicit about the requirements for each case.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e8bc4b3 and 755c03f.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

🔇 Additional comments (12)
cpp/tensorrt_llm/kernels/mlaKernels.cu (4)

2-2: LGTM! Copyright header correctly updated.

The copyright year range has been properly updated to include 2025, following the coding guidelines requirement for NVIDIA copyright headers.


187-187: LGTM! Kernel signature correctly updated for separate QKV layout.

The function signature change from a single qkv_output pointer to separate q_ptr and k_ptr parameters aligns perfectly with the PR objective of using separate QKV input layout for context MLA, eliminating the need for concatenation operations.


242-281: LGTM! Memory access patterns correctly updated for separate buffers.

The memory access changes properly separate query and key operations:

  • Query data is now correctly read from q_ptr instead of the previous unified buffer
  • Output writes to both q_ptr and k_ptr maintain the same indexing logic while targeting separate buffers
  • Vector operations and mathematical calculations remain consistent

This change supports the elimination of concatenation operations as intended by the PR.


826-826: LGTM! Kernel invocation correctly updated for separate QKV buffers.

The kernel launch parameters have been properly updated to pass separate params.attention_input_buf and params.k_buf arguments, matching the modified kernel signature that now expects separate q_ptr and k_ptr parameters.

tensorrt_llm/_torch/attention_backend/trtllm.py (4)

26-26: LGTM! Comprehensive addition of total_kv_lens tracking.

The total_kv_lens tensor attribute has been properly added across all relevant components:

  • Correctly declared in the dataclass with appropriate type annotation
  • Added to method signatures with clear documentation explaining its purpose
  • Properly initialized as a 2-element CPU tensor to track context and generation KV lengths separately
  • Computation logic correctly sums KV lengths excluding extra tokens
  • Consistently passed through all call chains

This change effectively supports the PR objective of explicit total KV length tracking instead of relying on MLA paged KV cache mechanisms.

Also applies to: 157-157, 201-201, 228-228, 412-412, 603-603, 728-730, 1129-1129


328-332: LGTM! MLA context handling correctly updated for separate QKV layout.

The assertion logic properly enforces the new MLA behavior:

  • Context-only mode now requires separate QKV (assert not is_fused_qkv), supporting the PR's separate QKV input layout
  • Generation-only mode maintains fused QKV behavior (assert is_fused_qkv) for consistency
  • Return type annotations updated to modern Python syntax with tuple[torch.Tensor, torch.Tensor]

These changes align perfectly with the PR objective of eliminating concatenation operations in the context path.

Also applies to: 1224-1224, 1266-1266


70-70: LGTM! Softmax statistics tensor properly added.

The softmax_stats_tensor attribute is correctly declared with appropriate typing and follows the established pattern for optional tensor attributes in the wrapper class.


1272-1278: LGTM! Empty tensor creation properly handled for edge case.

The empty tensor creation logic correctly handles the case when max_ctx_cached_token_len == 0:

  • Proper dimensions using MLA parameters for both KV and K_PE tensors
  • Maintains dtype and device consistency
  • Returns appropriate tuple format matching the function signature

This ensures robust handling of edge cases in the chunked KV cache loading.

cpp/tensorrt_llm/thop/attentionOp.cpp (4)

66-83: LGTM! Well-structured function signature updates for separate QKV support.

The signature changes appropriately introduce optional k and v tensors alongside the renamed qkv_or_q parameter, and the addition of total_kv_len parameter aligns with the PR objectives. The comment clearly explains the rationale for separate QKV inputs in context MLA.
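
Expressed in rough Python terms, the interface shape looks like the sketch below; this is a paraphrase for readability, not the actual C++ binding, and any parameter defaults beyond the names mentioned above are assumptions.

from typing import Optional
import torch

def attention_sketch(qkv_or_q: torch.Tensor,
                     k: Optional[torch.Tensor] = None,
                     v: Optional[torch.Tensor] = None,
                     total_kv_len: int = 0) -> torch.Tensor:
    # Context MLA passes separate Q/K/V; every other path passes fused QKV and leaves k/v as None.
    raise NotImplementedError("illustrative signature only")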


162-186: Excellent validation logic for MLA context processing.

The comprehensive tensor validation ensures proper input handling for separate QKV tensors:

  • Appropriate checks for tensor existence, dimensions, and memory layout
  • Correct handling of latent_cache for different MLA scenarios
  • Proper extraction of k_ptr and v_ptr for downstream processing

The validation follows defensive programming practices and maintains code robustness.
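
In Python terms, the layout checks described above roughly amount to the following; the assumed layout (flat 2-D buffers with a contiguous innermost dimension) is an illustration, not a restatement of the C++ code.

import torch

def check_separate_kv_inputs(k: torch.Tensor, v: torch.Tensor) -> None:
    for name, t in (("k", k), ("v", v)):
        assert t.dim() == 2, f"{name} must be 2-D [num_tokens, dim]"
        assert t.stride(-1) == 1, f"{name} innermost dimension must be contiguous"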


137-139: Consistent and clear variable renaming throughout the implementation.

The renaming from qkv to qkv_or_q accurately reflects the dual usage pattern and improves code clarity. The addition of total_kv_len parameter properly supports the new functionality requirements.

Also applies to: 290-290, 445-445, 457-457, 614-614, 653-653


412-415: Proper Torch library binding updates.

The binding definition correctly adds the total_kv_lens parameter and maintains consistency with the C++ function signature. The parameter placement and naming follow the established patterns.

Also applies to: 754-754

@zhhuang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #13801 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13801 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10375 completed with status: 'FAILURE'

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)

70-70: Consider using Optional typing for consistency.

The softmax_stats_tensor attribute should use Optional[torch.Tensor] typing to be consistent with other optional tensor attributes in the class.

-    softmax_stats_tensor: torch.Tensor
+    softmax_stats_tensor: Optional[torch.Tensor]

201-201: Fix line length violation.

This line exceeds the 120-character limit specified in the coding guidelines.

The line should be wrapped or shortened to comply with the coding standards. Consider breaking the parameter documentation across multiple lines or shortening the description.

cpp/tensorrt_llm/thop/attentionOp.cpp (1)

191-191: Remove unnecessary blank line.

There's an extra blank line that should be removed for consistency with the codebase formatting.

                mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
-
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 755c03f and 89cf5e1.

📒 Files selected for processing (19)
  • cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
  • tensorrt_llm/_torch/metadata.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
  • tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
🚧 Files skipped from review as they are similar to previous changes (11)
  • tensorrt_llm/_torch/metadata.py
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/kernels/mlaKernels.h
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • cpp/tensorrt_llm/common/attentionOp.cpp
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.550Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.550Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.

Applied to files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (31)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)

1-15: LGTM! Copyright header is compliant.

The copyright header correctly includes the current year (2025) and follows the required NVIDIA format for TensorRT-LLM Open Source Software.


353-353: LGTM! Template instantiation is correct.

The template instantiation for invokeMLALoadChunkedKV with FP8 cache type is properly formatted and syntactically correct.

cpp/tensorrt_llm/kernels/mlaKernels.cu (6)

2-2: LGTM! Copyright header updated correctly.

The copyright header properly includes the current year (2025) in compliance with TensorRT-LLM coding guidelines.


187-187: LGTM! Kernel signature updated for separate Q/K buffers.

The change from a single qkv_output pointer to separate q_ptr and k_ptr parameters aligns with the PR objective to use separate QKV input layout, eliminating concatenation operations.


242-243: LGTM! Q buffer offset calculation is correct.

The offset calculation properly computes the global token index multiplied by the head-specific stride, maintaining correct indexing for the separate Q buffer layout.
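
The arithmetic being checked is, in spirit, the following; this is a Python rendering under assumed names, while the real kernel performs the same computation on raw pointers in CUDA.

def q_elem_offset(global_token_idx: int, head_idx: int, num_heads: int, head_dim: int) -> int:
    # Each token occupies num_heads * head_dim elements in the separate Q buffer;
    # within a token, each head occupies head_dim contiguous elements.
    return global_token_idx * (num_heads * head_dim) + head_idx * head_dim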


245-245: LGTM! Q buffer access is consistent.

The buffer access correctly uses the separate q_ptr with the computed offset, maintaining consistency with the new separate buffer layout.


276-282: LGTM! Output assignments correctly use separate Q/K buffers.

The index calculations and buffer writes properly handle the separate Q and K outputs, completing the transformation from fused QKV to separate buffer handling. The indexing is mathematically consistent and follows the expected memory layout.


826-826: LGTM! Function call updated to match new kernel signature.

The kernel launch correctly passes params.k_buf as the separate K buffer argument, maintaining consistency with the updated kernel signature that now accepts separate Q and K pointers.

tensorrt_llm/_torch/attention_backend/trtllm.py (10)

26-26: LGTM!

The addition of total_kv_lens tensor attribute is correctly integrated into the dataclass definition and follows the established pattern of other tensor attributes.


157-157: LGTM!

The parameter addition is correctly documented and follows the established parameter naming conventions.


228-228: LGTM!

The assignment of total_kv_lens parameter is correctly implemented and maintains consistency with other tensor assignments in the method.


418-418: LGTM!

The total_kv_lens parameter is correctly passed to the attention operation call, maintaining consistency with the method signature updates.


609-609: LGTM!

The initialization of total_kv_lens tensor with proper shape and device placement is correctly implemented.


733-736: LGTM!

The computation of total KV lengths for context and generation requests is correctly implemented. The logic properly sums KV lengths without extra tokens and stores them in the appropriate tensor indices.


1137-1137: LGTM!

The total_kv_lens parameter is correctly passed to the wrapper's plan method, maintaining consistency throughout the call chain.


1232-1232: LGTM!

The return type annotations have been updated from single tensor to tuple of two tensors, which aligns with the changes mentioned in the AI summary about returning separate K and V buffers.

Also applies to: 1274-1274


1280-1286: LGTM!

The handling of empty KV cache case is correctly implemented with proper tensor creation for both compressed KV and K position encoding tensors.


357-367: MLA QKV Handling Assertions Are Correct

Searches confirm that in tensorrt_llm/_torch/modules/attention.py and the corresponding unit tests, AttentionInputType.context_only always invokes the multi-head attention (MHA) with separate q, k, v tensors, whereas generation_only uses a fused QKV tensor. The new assertions in trtllm.py simply enforce this existing behavior and align with the rest of the codebase. No changes required.

cpp/tensorrt_llm/thop/attentionOp.cpp (13)

66-83: LGTM!

The interface updates to support separate Q, K, V tensors and the new total_kv_len parameter are well-designed. The comment on line 66 clearly explains the context for using separate QKV inputs.


122-135: LGTM!

The Runner template class method signature correctly matches the base class interface with the new parameters for separate K/V tensors and total KV length.


137-141: LGTM!

The variable renaming from qkv to qkv_or_q better reflects the dual usage pattern, and the initialization of separate K/V pointers is correctly implemented.


290-290: LGTM!

The assignment of total_kv_len to the common enqueue parameters is correctly implemented and maintains the parameter flow through the attention pipeline.


307-308: LGTM!

The assignment of separate K and V pointers to the enqueue parameters for context stage is correctly implemented, supporting the new separate tensor input layout.


412-418: LGTM!

The function signature update to include separate K/V tensors and total_kv_lens parameter is correctly implemented and maintains consistency with the interface changes.


457-457: LGTM!

The variable rename from qkv to qkv_or_q is correctly applied throughout the function, maintaining consistency with the interface changes.


613-618: LGTM!

The extraction of total KV lengths for context and generation requests from the tensor is correctly implemented, providing the necessary values for separate processing paths.


652-652: LGTM!

The workspace tensor creation uses the renamed qkv_or_q variable consistently with other changes in the function.


662-667: LGTM!

The runner calls for both context and generation stages correctly pass the new parameters including separate K/V tensors and individual total KV lengths, maintaining the proper parameter flow through the attention pipeline.

Also applies to: 678-683


753-753: LGTM!

The torch library binding signature is correctly updated to include the total_kv_lens parameter, maintaining consistency with the C++ interface changes.


162-186: Double-check MLA Context Tensor Validation Consistency

Please confirm that the new C++ checks match the tensor shapes and memory layouts produced by the Python LLM API:

  • latent_cache
    • May be null for KV‐cache reuse or chunked contexts – ensure the Python layer omits or passes None as expected.
  • k and v tensors
    • Must be 2D (dim() == 2) with inner stride == 1 – verify that Python constructs these as contiguous [num_tokens, head_dim] buffers in context mode.
    • Slicing by token_offset (slice(0, token_offset)) should align with how tokens are buffered upstream.

If there’s any mismatch between these constraints and the Python‐side behavior (e.g. different stride layout, batched vs. flat dimensions), update either the C++ checks or the Python wrapper to ensure consistency.


443-445: Clarify MLA vs fused QKV assertion semantics

The current check in cpp/tensorrt_llm/thop/attentionOp.cpp (around lines 443–445) is written as an inclusive OR:

TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv,
    "Only fused QKV is supported for non-MLA attention now");
TLLM_CHECK_WITH_INFO(update_kv_cache,
    "KV cache update cannot be disabled now");
auto qkv_or_q = q;

This allows both is_mla_enable and is_fused_qkv to be true simultaneously. If the intention was to make these modes mutually exclusive (i.e., MLA excludes fused QKV and vice versa), the assertion should use XOR logic:

- TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv,
-     "Only fused QKV is supported for non-MLA attention now");
+ TLLM_CHECK_WITH_INFO(is_mla_enable != is_fused_qkv,
+     "MLA attention requires separate Q/K/V; non-MLA requires fused QKV");

Otherwise, if dual-mode operation (MLA + fused QKV) is acceptable, update the error message to clarify that both flags may be set together.

• Location: cpp/tensorrt_llm/thop/attentionOp.cpp:443–445
• No other MLA vs. fused-QKV checks were found elsewhere in the codebase.
• The update_kv_cache assertion is only present here and currently cannot be disabled.

Likely an incorrect or invalid review comment.
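
For reference, a quick Python check shows that the two conditions disagree only when both flags are true, which is exactly the dual-mode case discussed above.

for is_mla_enable in (False, True):
    for is_fused_qkv in (False, True):
        inclusive_or = is_mla_enable or is_fused_qkv   # current check
        exclusive_or = is_mla_enable != is_fused_qkv   # proposed XOR check
        print(f"mla={is_mla_enable} fused={is_fused_qkv} or={inclusive_or} xor={exclusive_or}")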

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🔭 Outside diff range comments (1)
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)

2-2: Update copyright header to include current year.

The copyright header shows 2022-2024, but the coding guidelines require including the current year (2025).

Apply this diff to fix the copyright header:

- * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION.  All rights reserved.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

157-157: Fix line length violation in docstring.

The addition of total_kv_lens parameter is correct, but the docstring on line 201 exceeds the 120 character limit.

Apply this diff to fix the line length:

-            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+            total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context 
+                requests and generation requests, with shape (2) on CPU.

Also applies to: 201-201

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 89cf5e1 and ab6a170.

📒 Files selected for processing (19)
  • cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
  • tensorrt_llm/_torch/metadata.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
  • tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🚧 Files skipped from review as they are similar to previous changes (11)
  • tensorrt_llm/_torch/metadata.py
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/mlaKernels.cu
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (1)
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
  • tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py

201-201: Line too long (148 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (26)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)

1-27: LGTM - Copyright header and includes are correct.

The copyright header properly includes 2025 as required by the coding guidelines, and the includes are appropriate for this CUDA kernel file.


343-357: LGTM - Template instantiation macro is correctly structured.

The macro properly instantiates the template functions for different data types and cache types. The modification on line 353 appears to be formatting-related and maintains the correct structure.

cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)

93-507: LGTM - Test class structure is well-organized.

The test class properly structures the MLA preprocessing tests. The removal of paged KV cache related functionality aligns with the PR objectives of simplifying MLA context handling. The remaining test infrastructure for load operations appears comprehensive and well-designed.

tensorrt_llm/_torch/attention_backend/trtllm.py (8)

26-26: LGTM - Addition of total_kv_lens attribute is well-integrated.

The new total_kv_lens tensor attribute provides explicit tracking of total KV lengths for context and generation requests, which aligns with the PR objectives. The attribute is properly typed and integrated into the class structure.

Also applies to: 228-228


70-70: LGTM - Softmax stats tensor attribute is properly positioned.

The softmax_stats_tensor attribute is correctly typed and positioned within the dataclass structure.


357-367: LGTM - MLA input type handling correctly implements separate QKV inputs.

The updated assertions correctly enforce separate Q, K, V inputs for context_only processing and fused QKV for generation_only, which aligns with the PR objectives of using separate QKV input layout for context MLA.


418-418: LGTM - Addition of total_kv_lens to attention operation call.

The total_kv_lens parameter is correctly added to the attention operation call, providing explicit KV length tracking as required by the updated kernel interface.


609-609: LGTM - Proper initialization of total_kv_lens tensor.

The total_kv_lens tensor is correctly initialized with shape (2) on CPU to store separate totals for context and generation requests.


734-736: LGTM - Correct computation of total KV lengths.

The computation properly separates context and generation request totals, providing the kernels with accurate KV length information for each request type.


1232-1232: LGTM - Return type updates and empty tensor handling are correct.

The explicit tuple return types for load_paged_kv_cache_for_mla and load_chunked_kv_cache_for_mla correctly reflect the separate KV and positional embedding outputs. The empty tensor handling for the edge case when max_ctx_cached_token_len is 0 is good defensive programming.

Also applies to: 1274-1274, 1280-1286


1137-1137: LGTM - Integration of total_kv_lens in wrapper.plan call.

The total_kv_lens parameter is correctly passed to the wrapper's plan method, ensuring the attention operation has access to the explicit KV length information.

cpp/tensorrt_llm/thop/attentionOp.cpp (15)

66-83: LGTM - Clean interface extension for separate QKV inputs.

The virtual method signature is properly updated to support separate K and V tensors alongside the existing QKV tensor, with clear documentation of the use case.


120-135: LGTM - Implementation signature correctly matches base class.

The method signature properly implements the virtual base class interface with consistent parameter types and naming.


137-141: LGTM - Proper variable initialization for separate QKV support.

The renaming to qkv_or_q and initialization of separate K/V pointers correctly reflects the new dual-purpose interface.


162-186: LGTM - Well-structured MLA context handling with proper validation.

The implementation correctly handles separate K and V tensors for MLA context mode with appropriate validation checks for tensor dimensions, strides, and presence. The latent cache handling properly accounts for different scenarios (standard context vs. KV cache reuse/chunked context).


289-289: LGTM - Correct parameter assignment.

The total_kv_len parameter is properly assigned to the enqueue parameters structure.


306-307: LGTM - Proper pointer assignment for context stage.

The K and V pointers are correctly assigned to the enqueue parameters for the context processing stage.


411-435: LGTM - Function signature properly updated for new interface.

The attention_inplace function signature is correctly updated to include the total_kv_lens parameter, maintaining consistency with the runner interface changes.


442-444: LGTM - Important validation checks for configuration consistency.

The checks ensure that either MLA is enabled or fused QKV is used, and that KV cache update is always enabled. These validations prevent invalid attention configurations after the refactoring.


456-456: LGTM - Correct usage of renamed variable.

The data type extraction correctly uses the renamed qkv_or_q tensor.


651-651: LGTM - Proper device inference for workspace allocation.

The workspace tensor creation correctly uses the device from the qkv_or_q tensor.


612-612: LGTM - Correct token count calculation.

The number of tokens is properly calculated from the first dimension of the input tensor.


615-616: LGTM - Proper extraction of total KV lengths.

The context and generation total KV lengths are correctly extracted from the total_kv_lens tensor using appropriate indexing.


661-666: LGTM - Context runner call properly updated with new parameters.

The runner call correctly passes the renamed qkv_or_q tensor, separate k and v tensors, and the context total KV length parameter.


677-682: LGTM - Generation runner call properly updated with new parameters.

The runner call correctly passes the new parameters including the generation total KV length, maintaining consistency with the context call pattern.


752-752: LGTM - Torch library binding properly updated.

The function binding correctly includes the new total_kv_lens parameter, maintaining consistency with the updated C++ function signature.

@zhhuang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #13920 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13920 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #10482 completed with status: 'FAILURE'

@zhhuang-nv
Collaborator Author

/bot kill

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
@zhhuang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #15632 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🔭 Outside diff range comments (6)
cpp/tensorrt_llm/common/attentionOp.cpp (2)

770-772: Under-allocation: include mFP8ContextMLA in BMM scale workspace sizes

fmha_bmm1_scale_size and fmha_bmm2_scale_size only account for mFP8ContextFMHA. FP8 context MLA also consumes these buffers (you pass bmm1/bmm2 scales into params.mla_param and use them in invokeMLAContextFp8Quantize). This can under-allocate the context workspace and corrupt subsequent allocations.

Apply this fix:

-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

1655-1683: Null q/k/v pointers when latent_cache is absent break FP8 quantization

invokeMLAContextFp8Quantize() requires params.mla_param->q_buf/k_buf/v_buf to be set. In the default MLA context path (latent_cache == nullptr), these pointers are never initialized before the quant call, causing TLLM_CHECK failures or undefined behavior.

Set q/k/v pointers for the default path prior to quantization:

             params.mla_param->host_bmm1_scale
                 = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
+            // Ensure q/k/v are set when latent_cache is absent (default separate Q/K/V input).
+            if (params.mla_param->q_buf == nullptr)
+            {
+                // attention_input already points to Q in context FMHA path.
+                params.mla_param->q_buf = const_cast<T*>(attention_input);
+            }
+            if (params.mla_param->k_buf == nullptr)
+            {
+                // k_ptr comes from THOP attention binding for separate K input.
+                params.mla_param->k_buf = const_cast<T*>(params.k_ptr);
+            }
+            if (params.mla_param->v_buf == nullptr)
+            {
+                // v_buf is const in MlaParams.
+                params.mla_param->v_buf = params.v_ptr;
+            }
             if (params.mla_param->latent_cache != nullptr)
             {
                 // compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
                 invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
             }
             if (mFP8ContextMLA)
             {
                 invokeMLAContextFp8Quantize(*params.mla_param, params.total_kv_len, stream);
             }
tensorrt_llm/_torch/modules/attention.py (1)

1164-1172: Syntax error in type annotation for latent_cache

The argument annotation is split across lines, producing invalid Python syntax.

Apply this fix:

-        latent_cache: torch.
-        Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
+        latent_cache: torch.Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
cpp/tensorrt_llm/thop/attentionOp.cpp (3)

140-147: Guard against non-contiguous/invalid Q layout before taking raw pointer

The enqueue path assumes the inner-most dimension is contiguous. Add light shape/stride checks to fail-fast when Q is not laid out as expected.

Apply this diff:

-        auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
-        T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
+        auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
+        TORCH_CHECK(qkv_or_q.dim() >= 2, "qkv_or_q must be at least 2D: [tokens, hidden]");
+        TORCH_CHECK(qkv_or_q.strides()[qkv_or_q.dim() - 1] == 1, "qkv_or_q inner-most dimension must be contiguous");
+        T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());

652-655: Fix potential null deref when beam_width > 1 but cache_indirection is absent

This uses cache_indirection.value() without a has_value() guard. Prefer the same safe pattern you already use inside Runner::run.

Apply this diff:

-    int32_t const max_attention_window_size
-        = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
+    int32_t const max_attention_window_size = beam_width == 1
+        ? attention_window_size
+        : (cache_indirection.has_value() ? cache_indirection.value().size(2) : attention_window_size);

576-601: Missing has_value() checks for MLA params (q_lora_rank/kv_lora_rank/...): will crash if unset

You call .value() on several optionals without verifying has_value(). Add checks and provide a clear error message.

Apply this diff:

     if (is_mla_enable)
     {
         // MLA does not support NVFP4 output yet.
         TLLM_CHECK(!is_fp4_out);
 
         TLLM_CHECK(host_kv_cache_pool_mapping.has_value());
         int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);
 
+        TLLM_CHECK(q_lora_rank.has_value() && kv_lora_rank.has_value() && qk_nope_head_dim.has_value()
+                && qk_rope_head_dim.has_value() && v_head_dim.has_value(),
+            "MLA requires q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, and v_head_dim");
+
         op->mIsMLAEnabled = true;
         op->mMLAParams = {static_cast<int>(q_lora_rank.value()), static_cast<int>(kv_lora_rank.value()),
             static_cast<int>(qk_nope_head_dim.value()), static_cast<int>(qk_rope_head_dim.value()),
             static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
             static_cast<int>(layer_num)};
♻️ Duplicate comments (13)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (8)

1850-1851: Document mask/layout fields for readability (avoid magic numbers)

The 0/1 (mask) and 3 (layout) columns correspond to AttentionMaskType and AttentionInputLayout respectively. Inline annotations improve clarity and reduce regressions.

Apply:

-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},

If this header is generated, consider emitting enum names (e.g., static_cast<int>(AttentionInputLayout::SEPARATE_Q_K_V)) in the generator.


4333-4334: SM100 NL-tiled: annotate mask/layout and confirm shared wrapper + implementation

  • Add inline comments for mask/layout.
  • Both rows share run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled. Confirm this is intended for both mask kinds.
  • Ensure the wrapper’s definition exists (see extern at Line 1083).
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},

4743-4744: SM120 NL-tiled BF16: annotate mask/layout and ensure wrapper definition exists

  • Add inline mask/layout comments.
  • Confirm the wrapper run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled (extern at Line 1224) is implemented and linked.
  • Validate that sharing the same wrapper for causal and non-causal is intentional.
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},

212-216: SM90 separate Q-K-V wrappers: ensure implementations exist to avoid link errors

The new externs for SM90 separate Q-K-V context paths look correct, but prior verification found no corresponding definitions for these symbols, which will lead to linker failures if still missing:

  • run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90 (Line 212)
  • run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90 (Line 216)

Please add their definitions in the appropriate .cu/.cpp unit (consistent with other run_fmha_v2_* wrappers) and ensure runtime dispatch reaches them.

Run this to confirm definitions exist (expects at least one non-header definition per symbol):

#!/bin/bash
set -euo pipefail
syms=(
  run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90
  run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90
)
for s in "${syms[@]}"; do
  echo ">>> Checking definition for $s"
  if ! rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\("; then
    echo "!! Missing definition for: $s"
    exit 1
  fi
done
echo "All SM90 S_q_k_v wrappers found."

1083-1083: SM100 NL-tiled separate Q-K-V wrapper: definition likely missing

Extern added:

  • run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled

A previous review found no implementation. Add the definition and ensure both non-causal and causal registry entries route correctly to this wrapper.

#!/bin/bash
set -euo pipefail
sym='run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled'
rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${sym}[[:space:]]*\\(" || {
  echo "Missing definition for: $sym"
  exit 1
}

1224-1224: SM120 NL-tiled BF16 separate Q-K-V wrapper: add implementation

Extern added:

  • run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled

This was previously flagged as missing. Provide the .cu/.cpp implementation and include it in the build.

#!/bin/bash
set -euo pipefail
sym='run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled'
rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${sym}[[:space:]]*\\(" || {
  echo "Missing definition for: $sym"
  exit 1
}

1254-1258: SM120 E4M3 64×64 separate Q-K-V wrappers: add missing implementations

Externs added:

  • run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled (Line 1254)
  • run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled (Line 1258)

Previous checks found no definitions. Implement both wrappers and ensure they are linked.

#!/bin/bash
set -euo pipefail
syms=(
  run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled
  run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled
)
for s in "${syms[@]}"; do
  echo ">>> Checking definition for $s"
  rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\(" || {
    echo "Missing definition for: $s"
    exit 1
  }
done
echo "All SM120 E4M3 S_q_k_v wrappers found."

4833-4834: SM120 E4M3 separate Q-K-V rows: annotate mask/layout and implement wrappers

  • Add inline comments for mask/layout to avoid magic numbers.
  • Ensure both E4M3 wrappers are implemented:
    • run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled (Line 4833/4834)
    • run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled (Line 4842/4843)
-{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
+{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
-{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
+{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
-{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
+{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
-{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
+{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},

Quick definition check:

#!/bin/bash
set -euo pipefail
syms=(
  run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled
  run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled
)
missing=0
for s in "${syms[@]}"; do
  if ! rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\(" ; then
    echo "Missing definition for: $s"
    missing=1
  fi
done
exit $missing

Also applies to: 4842-4843

cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h (1)

569-576: Fix KV head-indexing for GQA/MQA in Gmem_tile_q_k_v

The ctor currently sets idx = binfo.bidh for all tensors. That is incorrect when numKvHeads < numHeads (GQA/MQA) for K/V: the K/V head index must be kv_head_id = bidh / headsPerKv. Mirror the mapping used elsewhere.

Apply this patch:

-        // The row offset in the batched GEMM, including the sequence offset.
-        int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
-        // Add the head index.
-        int64_t idx = binfo.bidh;
+        // The row offset in the batched GEMM, including the sequence offset.
+        int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
+        // Add the head index (map Q head -> KV head for GQA/MQA).
+        int64_t idx;
+        if (qkv_offset == 0)
+        {
+            // Q tensor: one head per Q head.
+            idx = binfo.bidh;
+        }
+        else
+        {
+            // K or V tensor: map Q head id to KV head id.
+            int const num_heads = params.h;
+            int const num_kv_heads = max(1, params.h_kv);
+            int const heads_per_kv = max(1, num_heads / num_kv_heads);
+            int const kv_head_id = binfo.bidh / heads_per_kv;
+            idx = kv_head_id;
+        }
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)

944-962: Quantize kernel launch: guard dims and grid before launch

The quantize kernel is instantiated for QK_NOPE=128, QK_ROPE=64, V=128. Add runtime checks to fail clearly if params meta dims differ, and validate grid dims against device limits (prior suggestion).

Apply this at the top of invokeMLAContextFp8Quantize():

     TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing separate qkv to FP8");

     if (params.acc_q_len > 0)
     {
         constexpr int threads_per_block = 384;
         dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len, 48)), 1, params.head_num);
+        // Sanity check the expected head dims for this specialization.
+        TLLM_CHECK_WITH_INFO(params.meta.qk_nope_head_dim == 128
+                && params.meta.qk_rope_head_dim == 64
+                && params.meta.v_head_dim == 128,
+            "MLA FP8 quantization currently specialized for (qk_nope, qk_rope, v) = (128, 64, 128); got (%d, %d, %d).",
+            params.meta.qk_nope_head_dim, params.meta.qk_rope_head_dim, params.meta.v_head_dim);
+        // Validate grid dimensions to avoid exceeding device limits.
+        int device = 0, maxGridDimX = 0, maxGridDimZ = 0;
+        cudaGetDevice(&device);
+        cudaDeviceGetAttribute(&maxGridDimX, cudaDevAttrMaxGridDimX, device);
+        cudaDeviceGetAttribute(&maxGridDimZ, cudaDevAttrMaxGridDimZ, device);
+        TLLM_CHECK_WITH_INFO(grid.x <= maxGridDimX && grid.z <= maxGridDimZ,
+            "Quantize kernel grid dims exceed device limits: (%d,%d,%d) vs max (%d,*,%d).",
+            grid.x, grid.y, grid.z, maxGridDimX, maxGridDimZ);
cpp/kernels/fmha_v2/fmha_test.py (2)

168-169: Include SM121 in FP8 context MLA gating

Allow FP8 MLA tests to run on SM121 as well. This avoids false skips on 12.1 parts.

Apply this diff:

-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
-        pytest.skip("FP8 MLAs are only supported on sm120 currently.")
+    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in (120, 121):
+        pytest.skip("FP8 MLAs are only supported on sm120/sm121 currently.")

213-214: Include SM121 in FP8 generation MLA gating

Mirror the context change to avoid false skips on 12.1.

Apply this diff:

-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
-        pytest.skip("FP8 MLAs are only supported on sm120 currently.")
+    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in (120, 121):
+        pytest.skip("FP8 MLAs are only supported on sm120/sm121 currently.")
cpp/tensorrt_llm/thop/attentionOp.cpp (1)

168-185: Also validate K/V dtype and device to match Q (carry-over from prior review)

The context-MLA path still lacks dtype/device checks for k and v. This can cause UB when kernels reinterpret memory.

Apply this diff to add the missing validations:

                 TORCH_CHECK(k.has_value());
                 TORCH_CHECK(v.has_value());
                 TORCH_CHECK(k->dim() == 2);
                 TORCH_CHECK(v->dim() == 2);
                 TORCH_CHECK(k->strides()[1] == 1);
                 TORCH_CHECK(v->strides()[1] == 1);
+                TORCH_CHECK(k->scalar_type() == qkv_or_q.scalar_type(), "K dtype must match Q dtype");
+                TORCH_CHECK(v->scalar_type() == qkv_or_q.scalar_type(), "V dtype must match Q dtype");
+                TORCH_CHECK(k->scalar_type() == v->scalar_type(), "K and V dtypes must match");
+                TORCH_CHECK(k->device() == qkv_or_q.device() && v->device() == qkv_or_q.device(),
+                    "K/V must be on the same device as Q");
🧹 Nitpick comments (24)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (1)

1860-1861: SM90 softmax rows: annotate mask/layout and confirm shared wrapper is intended

  • Add inline comments for mask/layout to avoid magic numbers.
  • Both non-causal and causal rows reference the same run wrapper. Verify the wrapper branches on mAttentionMaskType correctly for CAUSAL vs PADDING.
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (5)

149-161: Remove duplicate early-return check for merge_op.

merge_op_val == 0 is checked twice back-to-back. The second check is redundant and can be removed to reduce instruction count and warp divergence.

Apply this diff:

     int64_t merge_op_val = merge_op[batch_idx];
     if (merge_op_val == 0)
     {
         return; // skip this batch
     }

-    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
-    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerThread;
-
-    if (merge_op_val == 0)
-    {
-        return; // skip this batch
-    }
+    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
+    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerThread;

162-185: Use 64-bit/size_t for global offsets to avoid potential overflow on long sequences.

Offsets derived from cu_q_seq_len (int64_t) are narrowed to int and used in pointer indexing. For large concatenated sequences, this can overflow 32-bit. Prefer int64_t/size_t for offsets.

If your workloads guarantee <2^31 elements, this is low risk; otherwise, consider the change below.

-    int const curr_q_len = static_cast<int>(cu_q_seq_len[batch_idx + 1] - cu_q_seq_len[batch_idx]);
-    int const global_q_offset = cu_q_seq_len[batch_idx];
+    int64_t const curr_q_len = cu_q_seq_len[batch_idx + 1] - cu_q_seq_len[batch_idx];
+    int64_t const global_q_offset = cu_q_seq_len[batch_idx];

     for (int local_token_idx = (threadIdx.x / KT::kVecPerHead) + blockIdx.x * KT::kTokenPerBlock;
          local_token_idx < curr_q_len; local_token_idx += gridDim.x * KT::kTokenPerBlock)
     {
         // load softmax stat
-        int const global_softmax_stats_offset = (global_q_offset + local_token_idx) * num_heads + head_idx;
+        size_t const global_softmax_stats_offset
+            = static_cast<size_t>(global_q_offset + local_token_idx) * static_cast<size_t>(num_heads)
+            + static_cast<size_t>(head_idx);
         float2 curr_stats = curr_softmax_stats[global_softmax_stats_offset];
         float2 pre_stats = pre_softmax_stats[global_softmax_stats_offset];

         // load attn
         typename KT::VecReader pre_attn_reader{};
         typename KT::VecReader curr_attn_reader{};
         typename KT::VecReader merged_attn_reader{};

-        int const global_attn_offset
-            = (global_q_offset + local_token_idx) * num_heads * head_size + head_idx * head_size;
+        size_t const global_attn_offset = static_cast<size_t>(global_q_offset + local_token_idx)
+            * static_cast<size_t>(num_heads) * static_cast<size_t>(head_size)
+            + static_cast<size_t>(head_idx) * static_cast<size_t>(head_size);

201-203: Use device fast-math for exp to avoid double promotion and improve throughput.

In device code, std::exp may promote to double and is slower. Use __expf (or expf) with float inputs for better performance and consistency with fmaxf.

-            float pre_shift = std::exp(pre_stats.x - merged_stats.x);
-            float curr_shift = std::exp(curr_stats.x - merged_stats.x);
+            float pre_shift = __expf(pre_stats.x - merged_stats.x);
+            float curr_shift = __expf(curr_stats.x - merged_stats.x);

Note: If you rely on precise IEEE semantics, prefer expf; __expf trades a bit of precision for speed.


246-261: Type kvSrc as TCache and make it const for correctness and clarity.

The cache storage element type is TCache, which can differ from T (e.g., FP8). Casting to T* is confusing and may mislead future refactors. Use TCache const* and keep the vectorized load const-correct.

-            auto* kvSrc = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
+            auto const* kvSrc
+                = reinterpret_cast<TCache const*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
             // head_idx === 0
             auto kvBlockIdx
                 = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast<int>(head_dim_vec_idx));
-            auto ld_data = (reinterpret_cast<typename KT::VecT*>(kvSrc))[kvBlockIdx];
+            auto ld_data = (reinterpret_cast<typename KT::VecT const*>(kvSrc))[kvBlockIdx];

138-221: Minor: consider aligning math intrinsics and documenting head_size assumptions.

  • You use fmaxf and std::exp together; after switching to __expf/expf, math intrinsics will be consistent.
  • Head size is hardcoded in traits (128) and checked at runtime; a brief comment in the kernel noting that KT::kHeadSize must be 128 would help future maintainers.
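
A minimal compile-time sketch of that note, assuming KT::kHeadSize is (or can be made) a constexpr trait member:

// Make the hardcoded head-size assumption explicit at compile time
// (hypothetical placement at the top of the chunked-prefill merge kernel).
static_assert(KT::kHeadSize == 128,
    "The chunked-prefill merge kernel traits are specialized for head_size == 128.");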
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp (1)

1-4: Header-check exception for LFS pointer files

This .cpp is a Git LFS pointer, so adding a copyright header would break the pointer format. Ensure the header/lint rules exclude cubin pointer files under cubin/ to avoid false positives in CI.

If needed, I can propose an exclude pattern update for your header-check tooling (e.g., regex to skip paths matching **/cubin/*.cubin.cpp).

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (1)

155-158: SEPARATE_Q_K_V runtime path: ensure q/k/v pointers are present

When using the separate Q/K/V layout, qPtr/kPtr/vPtr must be non-null. Add a defensive check to fail fast if callers forget to set them.

Apply this diff locally within FmhaDispatcher::run:

-        else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
+        else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
         {
             qkvLayout = kernels::QkvLayout::SeparateQkv;
+            TLLM_CHECK_WITH_INFO(runnerParams.qPtr && runnerParams.kPtr && runnerParams.vPtr,
+                "SEPARATE_Q_K_V layout requires non-null qPtr/kPtr/vPtr.");
         }
cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h (1)

634-640: Preserve move_col(steps) in all Q/K/V tile variants

To keep the API symmetric with rewind_col(int steps) and consistent across sibling tiles, add a default steps parameter to every move_col that currently lacks it.

Affected locations in
cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h:

  • Gmem_tile_q_k_v::move_col (around line 635)
  • Gmem_tile_contiguous_kv::move_col (around line 1190)

Proposed patch:

--- a/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
+++ b/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
@@ -633,7 +633,9 @@ struct Gmem_tile_q_k_v {
-    inline __device__ void move_col()
+    inline __device__ void move_col(int const steps = 1)
     {
-        q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
-        // Update col_in_bytes_ to ensure load predicates work
-        col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
+        q_k_v_ptr_    += (int64_t) COLS * (BITS_PER_ELEMENT / 8) * steps;
+        // Update col_in_bytes_ to ensure load predicates work
+        col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG       * steps;
     }
@@ -1188,7 +1188,9 @@ struct Gmem_tile_contiguous_kv {
-    inline __device__ void move_col()
+    inline __device__ void move_col(int const steps = 1)
     {
-        kv_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
-        // Update col_in_bytes_ to ensure load predicates work
-        col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
+        kv_ptr_       += (int64_t) COLS * (BITS_PER_ELEMENT / 8) * steps;
+        // Update col_in_bytes_ to ensure load predicates work
+        col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG       * steps;
     }
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

205-206: Docstring line exceeds 120 chars

Break the host_total_kv_lens docstring line to satisfy the style limit flagged by static analysis.

-            host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+            host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for
+                context requests and generation requests, with shape (2) on CPU.
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)

820-921: QuantizeCopyInputToFp8Kernel: consider generalizing grid.x token grouping

grid.x uses divUp(total_kv_len, 48). 48 is the LCM of tokens-per-block for QK and V for BF16/FP16, but not for FP32; it still works but may underutilize SMs on some shapes. Consider computing grid.x based on the larger of QK_TOKENS_PER_BLOCK and V_TOKENS_PER_BLOCK (or their LCM) per specialization to balance occupancy and iteration count.
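
A minimal host-side sketch of that idea; the helper name and parameters are illustrative, not the kernel's actual identifiers:

#include <numeric> // std::lcm (C++17)

// Derive the per-CTA token grouping from the specialization's tokens-per-block
// values instead of hardcoding 48, so other specializations keep full occupancy.
inline int quantizeTokensPerCta(int qkTokensPerBlock, int vTokensPerBlock)
{
    return std::lcm(qkTokensPerBlock, vTokensPerBlock);
}

// Illustrative launch-site usage:
//   dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len,
//       quantizeTokensPerCta(qkTokensPerBlock, vTokensPerBlock))), 1, params.head_num);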

tests/unittest/_torch/test_attention_mla.py (2)

671-673: Broaden enable_flash_mla gating beyond (9, 0)

Flash MLA appears to be available on newer architectures as well. Consider enabling it on SM100/SM120 too to avoid false negatives on capable hardware.

-                enable_flash_mla=torch.cuda.get_device_capability() == (9, 0),
+                enable_flash_mla=torch.cuda.get_device_capability() in {(9, 0), (10, 0), (12, 0)},

750-757: Line length in diagnostic prints exceeds 120

Split long f-strings for readability and to satisfy linters.

-            print(
-                f"{backend_name} output mean: {result.abs().mean().item()}, max: {result.abs().max().item()}"
-            )
-            print(
-                f"Reference output mean: {ref_result.abs().mean().item()}, max: {ref_result.abs().max().item()}"
-            )
-            print(
-                f"Difference mean: {(result - ref_result).abs().mean().item()}, max: {(result - ref_result).abs().max().item()}"
-            )
+            print(f"{backend_name} output mean: {result.abs().mean().item()}, "
+                  f"max: {result.abs().max().item()}")
+            print(f"Reference output mean: {ref_result.abs().mean().item()}, "
+                  f"max: {ref_result.abs().max().item()}")
+            print(f"Difference mean: {(result - ref_result).abs().mean().item()}, "
+                  f"max: {(result - ref_result).abs().max().item()}")
cpp/kernels/fmha_v2/setup.py (2)

3649-3660: Clarify gating intent for MLA vs normal return-softmax cases

The new skip_mla_combination fixes the earlier over-enumeration. Minor readability tweak: make the MLA layout intent explicit and avoid double negatives.

Apply this diff:

-        # for normal attention, we only need contiguous kv as input layout when returning softmax.
-        skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
-        # for context mla, we need separate qkv as input layout when returning softmax.
-        skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
+        # For normal attention w/ return_softmax, only CONTIGUOUS_Q_KV is needed.
+        skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
+        # For context MLA w/ return_softmax, only SEPARATE_Q_K_V is meaningful.
+        is_mla_layout = input_layout == InputLayout.SEPARATE_Q_K_V
+        skip_mla_combination = return_softmax and not is_mla_layout

6367-6391: Fix E712 comparisons to True/False in kernel filter (ruff)

Pythonic conditionals improve readability and satisfy lints. Replace explicit True/False comparisons.

Apply this diff:

-                  # Deepseek MLA (generation 576/512 paged)
-                  or (kspec.sm            in [90, 100, 120]
+                  # Deepseek MLA (generation 576/512 paged)
+                  or (kspec.sm            in [90, 100, 120]
                   and kspec.dtype         in ['bf16', 'e4m3_fp32']
                   and kspec.head_size     == 576
                   and kspec.head_size_v   == 512
                   and kspec.input_layout == InputLayout.Q_PAGED_KV
                   and kspec.sage_block_sizes is None
                   and kspec.version       == 2
-                  and kspec.cross_mha     == False
-                  and kspec.flash_attention == True
-                  and kspec.warp_specialization == False
-                  and kspec.tiled == True)
+                  and not kspec.cross_mha
+                  and kspec.flash_attention
+                  and not kspec.warp_specialization
+                  and kspec.tiled)
@@
-                  # Deepseek MLA (context 192/128 separate-q-k-v)
-                  or (kspec.sm            in [90, 100, 120]
-                  and kspec.dtype         in ['bf16', 'e4m3_fp32']
+                  # Deepseek MLA (context 192/128 separate-q-k-v)
+                  or (kspec.sm            in [90, 100, 120]
+                  and kspec.dtype         in ['bf16', 'e4m3_fp32']
                   and kspec.head_size     == 192
                   and kspec.head_size_v   == 128
                   and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
                   and kspec.sage_block_sizes is None
                   and kspec.version       == 2
-                  and kspec.cross_mha     == False
-                  and kspec.flash_attention == True
-                  and ((kspec.warp_specialization == True and kspec.alibi == False)   # sm90
-                    or (kspec.warp_specialization == False and kspec.tiled == True))  # non-sm90
-                  and kspec.enable_attn_logit_softcapping == False)
+                  and not kspec.cross_mha
+                  and kspec.flash_attention
+                  and ((kspec.warp_specialization and not kspec.alibi)      # sm90
+                    or (not kspec.warp_specialization and kspec.tiled))     # non-sm90
+                  and not kspec.enable_attn_logit_softcapping)
cpp/tensorrt_llm/common/attentionOp.cpp (1)

2628-2639: Using SEPARATE_Q_K_V and headSizeQkNope for context MLA is appropriate

Forcing AttentionInputLayout::SEPARATE_Q_K_V for context MLA and plumbing headSizeQkNope through matches the kernel-parameter needs introduced in the fmha runners. The headSizeV note is documented; consider an assertion or static check in debug builds in case future models diverge (see the sketch below).
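
A minimal sketch of such a check, assuming the invariant is head_size_v == qk_nope_head_dim (which holds for the current DeepSeek shapes, 128/128) and that the member names mirror mMLAParams as used elsewhere in this PR:

// Debug-build guard against future models diverging from the assumed MLA head sizes.
TLLM_CHECK_WITH_INFO(mMLAParams.v_head_dim == mMLAParams.qk_nope_head_dim,
    "Context MLA FMHA config assumes head_size_v == qk_nope_head_dim, got %d vs %d.",
    mMLAParams.v_head_dim, mMLAParams.qk_nope_head_dim);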

cpp/tensorrt_llm/kernels/mlaKernels.h (1)

54-104: q/k mutability justified; consider typing FP8 quant buffers

You made q_buf and k_buf mutable because RoPE modifies them in-place; v_buf is const, which is correct. Consider typing the FP8 quant buffers as __nv_fp8_e4m3* instead of void* to avoid repeated casts and improve readability across declarations and call sites.

If changing the header type is too intrusive now, at least add brief comments in callers where casts happen to reduce future confusion.
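
For illustration, a hedged sketch of typed declarations; the quant_* member names are hypothetical, and only the idea of replacing void* with __nv_fp8_e4m3* is shown:

#include <cuda_fp8.h> // __nv_fp8_e4m3

template <typename T>
struct MlaParamsSketch
{
    T* q_buf = nullptr;       // mutable: RoPE updates Q in place
    T* k_buf = nullptr;       // mutable: RoPE updates K in place
    T const* v_buf = nullptr; // read-only
    // Typed FP8 outputs of the context quantization step, instead of void*.
    __nv_fp8_e4m3* quant_q_buf = nullptr;
    __nv_fp8_e4m3* quant_k_buf = nullptr;
    __nv_fp8_e4m3* quant_v_buf = nullptr;
};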

cpp/tensorrt_llm/thop/attentionOp.h (1)

37-61: Minor consistency: prefer one optional type

This header mixes std::optional<torch::Tensor> and torch::optional<torch::Tensor>. Both work, but pick one style for consistency (project convention leans towards std::optional in C++17 code).

cpp/tensorrt_llm/thop/attentionOp.cpp (6)

67-85: Run() signature explosion: consider bundling args to improve safety and readability

The expanded run signature is getting unwieldy and error-prone. Consider introducing a small POD/struct (e.g., RunArgs) to group the many tensors and scalars. This reduces callsite mistakes, eases evolution, and matches the pattern you already use for enqueue params.
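
A minimal sketch of the idea; the field list below is illustrative and heavily trimmed, not the actual run signature:

// Assumes the ATen/torch headers already included by attentionOp.cpp.
struct RunArgs
{
    torch::Tensor qkv_or_q;                    // fused QKV, or separate Q
    std::optional<torch::Tensor> k;            // separate K (context MLA path)
    std::optional<torch::Tensor> v;            // separate V (context MLA path)
    std::optional<torch::Tensor> latent_cache; // compressed_kv + k_pe
    std::optional<torch::Tensor> workspace;
    int64_t token_offset = 0;
    int64_t total_kv_len = 0;
    int32_t beam_width = 1;
};

// Call sites then populate named fields instead of a long positional list,
// e.g. RunArgs args{...}; runner->run(args, stream);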


668-681: Workspace device selection is correct, but validate dtype/device when an external buffer is provided

If the provided workspace tensor is not byte-typed or lives on a different device than Q, implicit assumptions downstream could break. Consider validating both dtype and device.

Apply this diff:

     if (workspace_.has_value())
     {
         if (workspace_.value().numel() < workspace_size)
         {
             TLLM_LOG_WARNING("Attention workspace size is not enough, increase the size from %ld bytes to %ld bytes",
                 workspace_.value().numel(), workspace_size);
             workspace_.value().resize_({workspace_size});
         }
-        workspace = workspace_.value();
+        TORCH_CHECK(workspace_.value().dtype() == torch::kByte,
+            "Workspace tensor must have dtype=Byte");
+        TORCH_CHECK(workspace_.value().device() == qkv_or_q.device(),
+            "Workspace tensor must be on same device as inputs");
+        workspace = workspace_.value();
     }
     else
     {
         workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv_or_q.device()));
     }

469-487: Top-level guards on fused vs separate inputs make sense, but add explicit K/V dtype/device checks early

You validate presence of K/V when needed. For earlier/faster failure, mirror the dtype/device checks here (in addition to Runner::run), so errors are caught before dispatch/templating.

Would you like me to draft a minimal guard block here?
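
For reference, a minimal sketch of such a guard placed before dispatch; the is_fused_qkv flag and exact placement are assumptions, and the checks mirror the Runner::run diff above:

if (!is_fused_qkv)
{
    TORCH_CHECK(k.has_value() && v.has_value(), "Separate-QKV path requires K and V tensors");
    TORCH_CHECK(k->scalar_type() == qkv_or_q.scalar_type() && v->scalar_type() == qkv_or_q.scalar_type(),
        "K/V dtype must match Q dtype");
    TORCH_CHECK(k->device() == qkv_or_q.device() && v->device() == qkv_or_q.device(),
        "K/V must be on the same device as Q");
}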


285-292: Validate softmax_stats_tensor dtype/shape before casting to float2*

You correctly enforce the float32 dtype, but since the buffer is reinterpreted as float2, also assert that numel is even and that the innermost stride is 1 so the float2 reinterpretation is valid.

Proposed snippet (contextual):

TORCH_CHECK(softmax_stats_tensor->numel() % 2 == 0, "softmax_stats_tensor must have even number of floats for float2");
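
And a contextual sketch covering the layout half of the assertion, in the same spirit:

TORCH_CHECK(softmax_stats_tensor->strides()[softmax_stats_tensor->dim() - 1] == 1,
    "softmax_stats_tensor innermost dimension must be contiguous for float2 reinterpretation");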

1-16: Update copyright year

Per coding guidelines, use the current year. Update 1993-2024 to 1993-2025.

Apply this diff:

- * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &

683-712: Request: add targeted tests for separate-QKV context MLA across all paths

Given the behavioral changes (default/KV-reuse/chunked), please add tests that:

  • validate K/V dtype/device mismatch raises,
  • verify total_kv_len is honored in enqueue,
  • cover both SM90 and SM120 config gates (feature flags permitting).

I can scaffold unit tests for test_attention_mla.py covering these scenarios. Want me to draft them?

@zhhuang-nv
Copy link
Collaborator Author

I approve for the DS modeling file.

We need perf data before merge.

@litaotju @kaiyux Benchmark result attached in the description. Please take a look.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15632 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11769 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@kaiyux kaiyux merged commit 7e135d2 into NVIDIA:main Aug 19, 2025
9 checks passed
@zhhuang-nv zhhuang-nv deleted the mla-separate-qkv branch October 20, 2025 09:14