[None][feat] Use Separate QKV Input Layout for Context MLA #6538
Conversation
📝 Walkthrough
Refactors MLA/context attention to use separate Q, K, V buffers (including per-tensor FP8 quant buffers), removes MLA paged/chunked KV-cache setter kernels and helpers, introduces AttentionInputLayout::SEPARATE_Q_K_V across FMHA runners/kernels, updates C++/Python bindings and metadata to carry total_kv_len / host_total_kv_lens, and adapts tests and kernel selection/tiles accordingly.
Sequence Diagram(s)
sequenceDiagram
participant User as Python caller
participant PyLayer as Py MLA/attention layer
participant Backend as TrtllmAttentionWrapper
participant CppOp as AttentionOp (C++)
participant FMHA as FMHA runner / kernels
User->>PyLayer: forward(q, k?, v?, ...)
PyLayer->>Backend: plan/run(qkv_or_q, k?, v?, host_total_kv_lens, ...)
Backend->>CppOp: run(..., qkv_or_q, k?, v?, total_kv_len, ...)
CppOp->>FMHA: enqueueContext(q_ptr, k_ptr, v_ptr, quant_q/k/v buffers?, headSizeQkNope, layout=SEPARATE_Q_K_V)
FMHA->>FMHA: prepare kernel params (strides, dtypeKv, head dims) and launch kernels
FMHA-->>CppOp: results (output, optional softmax stats)
CppOp-->>Backend: return outputs
Backend-->>PyLayer: return outputs
PyLayer-->>User: final output
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
/bot run --disable-fail-fast
Actionable comments posted: 2
🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
1-3: Update copyright year to 2025. The copyright header should include the current year (2025) according to the coding guidelines.
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
681-681: Remove commented debug print statement. This appears to be leftover debugging code.
- # print("prepare TrtllmAttentionMetadata")
  extra_attrs = get_model_extra_attrs()
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
- cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (4 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (13 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4 hunks)
- tensorrt_llm/_torch/metadata.py (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
- tensorrt_llm/_torch/modules/attention.py (6 hunks)
💤 Files with no reviewable changes (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
- tensorrt_llm/_torch/metadata.py
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The block should be prepended to the top of all files, including .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
- tensorrt_llm/_torch/metadata.py
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/common/attentionOp.h
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- tensorrt_llm/_torch/modules/attention.py
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- cpp/tensorrt_llm/kernels/mlaKernels.cu
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
- cpp/tensorrt_llm/kernels/mlaKernels.h
- cpp/tensorrt_llm/common/attentionOp.cpp
- tensorrt_llm/_torch/attention_backend/trtllm.py
- cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camelcase prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-scope magic-number/literal constants are uppercase snakecas...
Files:
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- cpp/tensorrt_llm/kernels/mlaKernels.h
- cpp/tensorrt_llm/common/attentionOp.cpp
- cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- cpp/tensorrt_llm/kernels/mlaKernels.h
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T04:09:12.904Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T04:09:12.904Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The block should be prepended to the top of all files, including .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
- cpp/tensorrt_llm/common/attentionOp.cpp
- tensorrt_llm/_torch/attention_backend/trtllm.py
🧬 Code Graph Analysis (2)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)
cpp/tensorrt_llm/kernels/multiHeadAttentionCommon.h (2)
get_size_in_bytes (81-95), get_size_in_bytes (99-102)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (4)
void (75-104), void (107-136), void (141-224), void (229-295)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (48)
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1)
239-240: LGTM: Clear addition of MLA-specific parameter. The addition of mHeadDimQkNope with appropriate documentation clearly indicates its purpose for MLA context handling. The parameter naming follows the established camelCase convention and the comment provides sufficient context.
tensorrt_llm/_torch/metadata.py (1)
31-31: LGTM: Simplified type annotation aligns with refactoring. The change from Optional[List[int]] to Optional[int] simplifies the interface and is consistent with the broader refactoring to handle separate Q, K, V tensors. The field name and default value remain appropriate.
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1)
361-365: LGTM: Correct stride calculation for non-contiguous V tensor. The new conditional branch properly handles the stride calculation for the V tensor in context MLA when using the separate QKV layout. The calculation options.mNumHeadsKv * (options.mHeadDimQkNope + options.mHeadDimV) correctly accounts for the non-contiguous layout by including both dimensions, and the comment clearly explains the context-specific requirement.
cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (4)
39-42: LGTM: Proper mapping for new SEPARATE_Q_K_V layout. The new case correctly maps AttentionInputLayout::SEPARATE_Q_K_V to QkvLayout::SeparateQkv, maintaining consistency with the established pattern for layout conversions.
155-158: LGTM: Consistent layout handling in dispatcher. The additional case for the SEPARATE_Q_K_V layout is properly handled, setting the qkvLayout to SeparateQkv as expected for the new attention input layout.
172-173: LGTM: Explicit assignment of separate K and V pointers. The assignment of kPtr and vPtr from runnerParams enables proper handling of separate K and V tensors, replacing the previous nullptr assignments and supporting the new separate QKV input layout.
191-191: LGTM: Proper propagation of MLA-specific parameter. The assignment of mHeadDimQkNope from mFixedParams.headSizeQkNope correctly propagates the MLA-specific head dimension parameter to the runner parameters.
cpp/tensorrt_llm/common/attentionOp.h (3)
96-96: LGTM: Added necessary KV length tracking. The addition of total_kv_len to EnqueueParams provides essential tracking for the total key-value length, supporting the new separate QKV input layout functionality.
130-132: LGTM: Well-documented separate K and V pointer support. The addition of optional k_ptr and v_ptr members is properly documented, clearly indicating their use for separate QKV input specifically for context MLA. The optional nature and specific use case are well communicated.
184-185: LGTM: Complete debug output coverage. The addition of the new K and V pointers to the debug string output ensures comprehensive logging and troubleshooting capabilities for the new separate QKV functionality.
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2)
181-202: LGTM! Well-structured separate QKV layout implementation. The implementation correctly handles the new SEPARATE_Q_K_V layout by setting separate pointers and calculating appropriate strides. The conditional logic for the V tensor stride based on headSizeQkNope properly handles both contiguous and non-contiguous V tensor layouts, which aligns with the MLA context requirements.
633-638: LGTM! Consistent TMA descriptor setup for separate QKV. The implementation correctly sets the K and V pointers from kernel parameters for the SEPARATE_Q_K_V layout, maintaining consistency with the established pattern used for other layout types.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
896-896: LGTM! Consistent addition of total_kv_lens parameter. The total_kv_lens parameter is correctly added to the function signature, passed through to the underlying attention_inplace operation, and included in the fake registration. This supports the transition to explicit KV length tracking. Also applies to: 980-980, 1010-1010
1060-1060: LGTM! Parameter ordering alignment. The reordering of softmax_stats_tensor in the fake registration ensures consistency with the function signature and underlying operation interface.
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1)
353-353: LGTM! Cleanup aligns with PR objectives. The removal of the MLA chunked KV cache setting kernels aligns with the PR's goal of eliminating these specialized operations in favor of separate QKV input handling. The remaining functionality for merging attention and loading chunked KV is preserved as expected.
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4)
84-87: LGTM: Well-documented enum extension for separate QKV layout. The addition of SEPARATE_Q_K_V to the AttentionInputLayout enum is well-implemented, with clear documentation describing the layout structure and its specific use case for context MLA.
121-122: LGTM: Appropriate addition of MLA-specific parameter. The headSizeQkNope member is correctly added with proper documentation indicating its MLA-specific usage for the Q/K non-RoPE part dimension.
174-174: LGTM: Consistent string representation for new enum value. The string conversion method correctly handles the new SEPARATE_Q_K_V layout with appropriate naming that matches the enum value.
264-267: LGTM: Logical addition of separate K and V pointers. The new kPtr and vPtr members are appropriately added to support the separate QKV input layout, complementing the existing qPtr and kvPtr pointers with clear documentation.
cpp/tensorrt_llm/kernels/mlaKernels.h (1)
55-57: LGTM: Clean refactoring to separate QKV buffers. The modification of MlaParams to support separate Q, K, V buffers is well-implemented:
- Clear documentation of the new buffer layouts and dimensions
- Logical separation of concerns with dedicated buffers for context MLA
- Consistent naming conventions following the existing pattern
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)
119-119: LGTM: Appropriate test cleanup for removed functionality. The modification reflects the removal of test infrastructure for the deleted paged KV cache setting kernel. The remaining test coverage for the loadPagedKV functionality is still intact and appropriate.
1629-1649: Clear documentation for the new separate QKV input layout. The extended comment properly documents the new separate QKV input layout for context MLA, including all required pointers and parameters. This improves code maintainability.
1663-1676: Correct implementation of separate QKV buffer pointers for MLA. The conditional assignment of FMHA device buffer pointers properly handles both MLA (separate Q, K, V) and non-MLA (packed QKV or Q-only) cases, maintaining backward compatibility.
1687-1688: Good simplification of paged KV cache assignment. The direct assignment eliminates unnecessary conditional logic, improving code clarity and aligning with the PR's goal of removing extra operations.
2537-2547: Proper initialization of MLA-specific FMHA parameters. The initialization correctly sets the attention input layout to SEPARATE_Q_K_V and includes the new headSizeQkNope parameter when MLA is enabled, completing the support for separate QKV input layout.
cpp/tensorrt_llm/kernels/mlaKernels.cu (4)
187-189: LGTM! Function signature properly updated for separate Q/K buffers. The change from a single qkv_output to separate q_ptr and k_ptr aligns with the PR objective to eliminate concatenation operations in the default path.
242-246: Correct update to use separate query buffer. The change properly loads query data from the dedicated q_ptr buffer instead of fuse_buf, maintaining the same offset calculations.
826-828: Kernel invocation correctly updated for separate K buffer. The change from params.latent_cache to params.k_buf as the second parameter correctly matches the updated kernel signature that expects separate Q and K pointers.
276-282: Ignore identical Q/K indexing concern. The matching dst_q_idx and dst_k_idx formulas are intentional: both Q and K undergo the same RoPE offset calculation, and because q_ptr and k_ptr point to separate output buffers, the identical index correctly places each tensor's data in its own memory. No change is needed here. Likely an incorrect or invalid review comment.
tensorrt_llm/_torch/modules/attention.py (6)
944-949: Clean implementation of explicit K tensor construction. The code correctly constructs the full key tensor by:
- Reshaping full_k_nope to separate head dimensions
- Broadcasting full_k_pe across all attention heads
- Concatenating along the feature dimension
- Flattening to the expected shape
This explicit approach successfully replaces the previous fused QKV handling (see the sketch after this comment).
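A minimal PyTorch sketch of this construction pattern is shown below. It is illustrative only: the sizes and helper names (num_tokens, num_heads, qk_nope_head_dim, qk_rope_head_dim) are assumptions rather than the module's actual values; only the reshape/broadcast/concatenate/flatten sequence mirrors the reviewed code.

```python
import torch

# Assumed sizes for illustration only; the real module derives them from the MLA config.
num_tokens, num_heads = 8, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64

full_k_nope = torch.randn(num_tokens, num_heads * qk_nope_head_dim)  # per-head "nope" part
full_k_pe = torch.randn(num_tokens, qk_rope_head_dim)                # RoPE part, shared by all heads

# 1) Reshape the nope part to expose the per-head dimension.
k_nope = full_k_nope.view(num_tokens, num_heads, qk_nope_head_dim)
# 2) Broadcast the RoPE part across all attention heads.
k_pe = full_k_pe.unsqueeze(1).expand(num_tokens, num_heads, qk_rope_head_dim)
# 3) Concatenate along the feature dimension, then 4) flatten to [num_tokens, num_heads * head_dim].
full_k = torch.cat([k_nope, k_pe], dim=-1).reshape(num_tokens, -1)

assert full_k.shape == (num_tokens, num_heads * (qk_nope_head_dim + qk_rope_head_dim))
```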
960-969: Correct adaptation to separate K/V tensor interface. The mha.forward call properly passes separate full_k and full_v tensors, and the explicit None for latent_cache with an explanatory comment helps distinguish this cached-KV code path.
1015-1015: Proper handling of total_kv_lens during chunked attention. The code correctly:
- Saves the original total_kv_lens[0] value
- Temporarily updates it for each chunk and the final attention computation
- Restores the original value after processing
This pattern ensures the attention kernels receive the correct KV lengths for each processing stage (a sketch of the pattern follows the note below).
Also applies to: 1060-1061, 1105-1108, 1131-1131
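A minimal sketch of the save/override/restore pattern described above. The helper names (run_chunk, q_chunks, kv_lens_per_chunk) are hypothetical; only the handling of total_kv_lens[0] reflects the reviewed pattern.

```python
import torch

def run_chunked_context(run_chunk, q_chunks, kv_lens_per_chunk, total_kv_lens: torch.Tensor):
    """Illustrative only: temporarily override total_kv_lens[0] per chunk, then restore it."""
    original_ctx_kv_len = int(total_kv_lens[0])  # save the original value
    try:
        kv_len_so_far = 0
        for q_chunk, chunk_kv_len in zip(q_chunks, kv_lens_per_chunk):
            kv_len_so_far += chunk_kv_len
            total_kv_lens[0] = kv_len_so_far      # temporary value for this chunk
            run_chunk(q_chunk, total_kv_lens)
    finally:
        total_kv_lens[0] = original_ctx_kv_len    # restore for subsequent calls
```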
1033-1054: Consistent K tensor construction with proper memory management. The code follows the established pattern for constructing K tensors and includes proper cleanup of intermediate variables to free memory. The explicit tensor reshaping and concatenation correctly implements the separate Q/K/V approach.
1089-1100: Consistent implementation for uncached KV handling. The K tensor construction for uncached KV follows the same proven pattern, ensuring consistency across all code paths in the chunked prefill implementation.
1112-1115: Final attention call correctly uses separate K/V tensors. The last mha.forward call in the chunked prefill path properly passes separate k and v tensors, maintaining consistency with the new separate QKV architecture.
tensorrt_llm/_torch/attention_backend/trtllm.py (8)
26-26: LGTM! Addition of total_kv_lens tensor. The addition of the total_kv_lens tensor to track total key-value lengths is well-integrated with the existing tensor attributes.
157-157: Well-documented addition of total_kv_lens parameter. The total_kv_lens parameter is properly added to the plan() method with clear documentation explaining its purpose and shape (2) on CPU. Also applies to: 201-201, 228-228
328-340: Correct implementation of separate QKV inputs for MLA context. The changes properly enforce separate Q, K, V inputs for MLA context mode while maintaining fused QKV for generation mode. The hidden size calculation is correctly updated to use only Q dimensions.
412-412: LGTM! Correctly passing total_kv_lens to attention operation.
603-603: Proper initialization and computation of total KV lengths. The total_kv_lens tensor is correctly initialized with shape (2) on CPU and properly computed by summing the KV lengths for context and generation requests separately (see the sketch below). Also applies to: 728-731
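A toy illustration of how such a 2-element CPU tensor could be filled. The per-request length lists below are made up for the example; the real values come from the request metadata.

```python
import torch

# Hypothetical per-request KV lengths.
context_kv_lens = [512, 1024, 37]   # context-phase requests
generation_kv_lens = [2048, 96]     # generation-phase requests

total_kv_lens = torch.zeros(2, dtype=torch.int64, device="cpu")
total_kv_lens[0] = sum(context_kv_lens)     # index 0: total KV length over context requests
total_kv_lens[1] = sum(generation_kv_lens)  # index 1: total KV length over generation requests
```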
1130-1130: LGTM! Propagating total_kv_lens through forward method.
1225-1225: Good use of modern Python type annotation. Using lowercase tuple is the preferred syntax for Python 3.9+.
1267-1279: Good edge case handling for empty cached tokens. The addition of an early return with properly shaped empty tensors when there are no cached tokens is an efficient optimization that avoids unnecessary kernel calls.
cpp/tensorrt_llm/thop/attentionOp.cpp (5)
66-71: Well-documented interface update for separate QKV inputs. The comment clearly explains the typical single QKV input case and the exception for context MLA using separate inputs. The addition of optional k and v tensors with the total_kv_len parameter properly extends the interface.
137-186: Correct implementation of separate K, V buffer handling for MLA context. The code properly:
- Renames the variable to qkv_or_q for clarity
- Validates that K, V tensors are provided for MLA context
- Checks tensor dimensions and strides
- Correctly extracts and assigns buffer pointers
290-290: Proper propagation of new parameters through enqueue flow. The total_kv_len and separate K, V pointers are correctly added to the enqueue parameters structure. Also applies to: 307-308
415-416: Comprehensive update to support separate QKV inputs. The changes correctly:
- Add the total_kv_lens parameter to track KV lengths
- Ensure backward compatibility by requiring fused QKV for non-MLA
- Extract and use separate total KV lengths for context and generation
Also applies to: 443-445, 617-618
653-653: Consistent updates for renamed variables and new parameters. The workspace tensor creation correctly uses the renamed qkv_or_q variable, and the TORCH_LIBRARY definition properly includes the new total_kv_lens parameter. Also applies to: 754-754
PR_Github #13764 [ run ] triggered by Bot
/bot kill
PR_Github #13796 [ kill ] triggered by Bot
PR_Github #13764 [ run ] completed with state
PR_Github #13796 [ kill ] completed with state
fdfba9e to d47541e
Actionable comments posted: 0
🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
1-15: Update copyright year to 2025. According to the coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header that includes the current year.
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
681-681: Remove commented-out debug print statement.
- # print("prepare TrtllmAttentionMetadata")
201-201: Line exceeds maximum length of 120 characters.
- total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+ total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and
+ generation requests, with shape (2) on CPU.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
- cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (4 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (13 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (4 hunks)
- tensorrt_llm/_torch/metadata.py (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
- tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🚧 Files skipped from review as they are similar to previous changes (12)
- tensorrt_llm/_torch/metadata.py
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/mlaKernels.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
- tensorrt_llm/_torch/modules/attention.py
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- cpp/tensorrt_llm/common/attentionOp.cpp
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
- cpp/tensorrt_llm/kernels/mlaKernels.cu
- tensorrt_llm/_torch/attention_backend/trtllm.py
- cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic...
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (2)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (13)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1)
1-360: LGTM! Removal of MLA chunked KV cache setting kernels aligns with PR objectives. The removal of setChunkedKVKernelTraits, setChunkedKVCacheForMLAKernel, and invokeMLASetChunkedKV successfully eliminates the MLA-specific chunked KV cache setting functionality as intended. The retained functions for merging attention and loading chunked KV cache serve different purposes and are correctly preserved.
cpp/tensorrt_llm/kernels/mlaKernels.cu (2)
186-282: LGTM! Successful refactoring to support separate Q, K, V buffers. The modification of applyMLARopeAndAssignQKVKernelOptContext to accept separate q_ptr and k_ptr parameters correctly implements the new SEPARATE_Q_K_V layout. The kernel properly writes Q data to the dedicated Q buffer and K data to the K buffer, eliminating the need for the removed paged KV cache setting operations.
825-829: LGTM! Clean removal of MLA paged KV cache setting functionality. The removal of setPagedKVKernelTraits, setPagedKVCacheForMLAKernel, and invokeMLASetPagedKV successfully eliminates the MLA-specific paged KV cache setting functionality as intended. The kernel invocation correctly uses the new separate buffer interface.
tensorrt_llm/_torch/attention_backend/trtllm.py (3)
26-26: LGTM! Well-integrated addition of total_kv_lens tensor. The new total_kv_lens tensor is properly integrated throughout the attention wrapper and metadata classes. It correctly tracks the total key-value lengths for context and generation requests separately, which aligns with the new separate Q, K, V buffer design. Also applies to: 157-157, 201-201, 228-228, 412-412, 728-731, 1130-1130
328-332: LGTM! Correct enforcement of separate Q, K, V inputs for MLA. The updated assertions properly enforce that the MLA context phase requires separate Q, K, V inputs (not is_fused_qkv) while the generation phase still uses fused QKV. This aligns with the removal of fused QKV support in MLA context operations.
1273-1279: LGTM! Proper handling of empty tensor case in load_chunked_kv_cache_for_mla. The function correctly returns empty tensors with appropriate shapes when there are no cached tokens, preventing potential runtime errors.
cpp/tensorrt_llm/thop/attentionOp.cpp (7)
66-83: Well-designed interface update for separate QKV support. The function signature changes clearly support the new separate QKV input layout for MLA context processing. The comment effectively explains the dual usage pattern, and the parameter naming (qkv_or_q, k, v, total_kv_len) is descriptive and intuitive.
120-135: Consistent implementation of the updated interface. The Runner::run method signature correctly matches the updated RunnerBase interface, ensuring proper virtual function override behavior.
137-210: Robust implementation of separate QKV tensor handling. The implementation correctly handles both fused and separate QKV tensor inputs with appropriate validation:
- Proper dimensional and stride checks for K/V tensors in MLA context mode
- Consistent pointer initialization and assignment
- Clear branching logic between context and generation modes
- Memory layout assumptions are properly validated
The validation checks ensure runtime safety while maintaining performance-critical assumptions about tensor layout.
290-290: Correct parameter propagation to enqueue operations. The new total_kv_len, k_ptr, and v_ptr parameters are properly assigned to the enqueue parameter structures, ensuring the separate tensor information flows correctly through the attention pipeline. Also applies to: 307-308
412-436: Well-structured function signature update. The addition of the total_kv_lens parameter is logically positioned with other length-related parameters, maintaining good parameter organization and readability.
443-445: Comprehensive implementation with robust validation. The implementation correctly handles the transition to separate QKV support:
- Clear validation ensures either MLA is enabled or fused QKV is used, preventing invalid configurations
- Consistent parameter renaming (qkv_or_q) throughout the function
- Proper extraction and passing of total_kv_lens values for both context and generation stages
- Informative error messages guide users toward correct usage
The dual-path support (context and generation) maintains the same parameter passing pattern, ensuring consistency.
Also applies to: 614-685
754-754: Consistent Torch library binding update. The addition of the total_kv_lens parameter to the Torch library definition correctly reflects the updated C++ function signature, maintaining proper Python-to-C++ interface consistency.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
201-201: Line exceeds maximum length limit. Line 201 is 148 characters long, exceeding the 120-character limit.
Consider breaking this line for better readability:
- total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+ total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and
+ generation requests, with shape (2) on CPU.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
- cpp/tensorrt_llm/kernels/mlaKernels.cu
- tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/attention_backend/trtllm.py
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/tensorrt_llm/kernels/mlaKernels.cu (2)
2-2: Copyright year correctly updated.
187-282: Kernel correctly updated for separate Q/K/V layout. The changes properly implement the separation of QKV tensors:
- Function signature updated to accept separate q_ptr and k_ptr parameters
- Memory access patterns correctly adjusted for reading from and writing to separate buffers
- Kernel invocation updated to pass params.k_buf instead of params.latent_cache
Also applies to: 826-826
tensorrt_llm/_torch/attention_backend/trtllm.py (3)
26-26: Total KV length tracking correctly implemented. The addition of the total_kv_lens tensor properly tracks the total key-value lengths for context and generation requests separately, which aligns with the PR objectives. Also applies to: 157-157, 228-228, 412-412, 603-603, 727-730, 1129-1129
328-333: MLA context mode correctly enforces separate Q/K/V inputs. The assertions properly ensure that MLA context mode uses separate inputs instead of fused QKV, aligning with the PR's goal of eliminating extra operations.
1224-1278: Load functions correctly updated to return separate tensors. The changes to load_paged_kv_cache_for_mla and load_chunked_kv_cache_for_mla properly return tuples of (compressed_kv, k_pe) instead of single tensors, supporting the new separate Q/K/V layout. The empty tensor handling for the no-cache case is also correctly implemented.
e8bc4b3 to 755c03f
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
201-201: Consider breaking long documentation line for better readability. The documentation line exceeds the 120-character limit (148 characters). Consider reformatting:
- total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+ total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests
+ and generation requests, with shape (2) on CPU.
cpp/tensorrt_llm/thop/attentionOp.cpp (1)
443-444: Clarify the validation logic for better readability. The validation condition is_mla_enable || is_fused_qkv with the error message "Only fused QKV is supported for non-MLA attention now" could be confusing. Consider making the intent clearer:
-TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
+TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Either MLA must be enabled or fused QKV must be used");
Or restructure the logic to be more explicit about the requirements for each case.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
- cpp/tensorrt_llm/kernels/mlaKernels.cu
- tensorrt_llm/_torch/attention_backend/trtllm.py
- cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic...
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-01T07:34:42.734Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
🔇 Additional comments (12)
cpp/tensorrt_llm/kernels/mlaKernels.cu (4)
2-2: LGTM! Copyright header correctly updated. The copyright year range has been properly updated to include 2025, following the coding guidelines requirement for NVIDIA copyright headers.
187-187: LGTM! Kernel signature correctly updated for separate QKV layout. The function signature change from a single qkv_output pointer to separate q_ptr and k_ptr parameters aligns with the PR objective of using a separate QKV input layout for context MLA, eliminating the need for concatenation operations.
242-281: LGTM! Memory access patterns correctly updated for separate buffers. The memory access changes properly separate query and key operations:
- Query data is now correctly read from q_ptr instead of the previous unified buffer
- Output writes to both q_ptr and k_ptr maintain the same indexing logic while targeting separate buffers
- Vector operations and mathematical calculations remain consistent
This change supports the elimination of concatenation operations as intended by the PR.
826-826: LGTM! Kernel invocation correctly updated for separate QKV buffers. The kernel launch parameters have been properly updated to pass separate params.attention_input_buf and params.k_buf arguments, matching the modified kernel signature that now expects separate q_ptr and k_ptr parameters.
tensorrt_llm/_torch/attention_backend/trtllm.py (4)
26-26: LGTM! Comprehensive addition of total_kv_lens tracking. The total_kv_lens tensor attribute has been properly added across all relevant components:
- Correctly declared in the dataclass with appropriate type annotation
- Added to method signatures with clear documentation explaining its purpose
- Properly initialized as a 2-element CPU tensor to track context and generation KV lengths separately
- Computation logic correctly sums KV lengths excluding extra tokens
- Consistently passed through all call chains
This change effectively supports the PR objective of explicit total KV length tracking instead of relying on MLA paged KV cache mechanisms.
Also applies to: 157-157, 201-201, 228-228, 412-412, 603-603, 728-730, 1129-1129
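To make the bookkeeping above concrete, here is a minimal sketch of how a 2-element CPU tensor of total KV lengths can be filled; the helper name and the per-request length lists are illustrative assumptions, not the actual TrtllmAttentionMetadata code.

```python
from typing import Sequence

import torch


def compute_total_kv_lens(ctx_kv_lens: Sequence[int],
                          gen_kv_lens: Sequence[int]) -> torch.Tensor:
    """Pack summed context/generation KV lengths into a 2-element CPU tensor (illustrative)."""
    total_kv_lens = torch.zeros(2, dtype=torch.int64, device="cpu")
    total_kv_lens[0] = sum(ctx_kv_lens)  # total KV tokens over context requests
    total_kv_lens[1] = sum(gen_kv_lens)  # total KV tokens over generation requests
    return total_kv_lens


# Example: two context requests (128 and 64 KV tokens) and one generation request.
print(compute_total_kv_lens([128, 64], [1]))  # tensor([192, 1])
```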
328-332: LGTM! MLA context handling correctly updated for separate QKV layout. The assertion logic properly enforces the new MLA behavior:
- Context-only mode now requires separate QKV (assert not is_fused_qkv), supporting the PR's separate QKV input layout
- Generation-only mode maintains fused QKV behavior (assert is_fused_qkv) for consistency
- Return type annotations updated to modern Python syntax with tuple[torch.Tensor, torch.Tensor]
These changes align perfectly with the PR objective of eliminating concatenation operations in the context path.
Also applies to: 1224-1224, 1266-1266
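The mode-dependent input contract enforced by those assertions can be sketched as follows; the enum members mirror the AttentionInputType values mentioned in this review, but the function itself is illustrative rather than the backend implementation.

```python
import enum
from typing import Optional

import torch


class AttentionInputType(enum.Enum):
    context_only = 0
    generation_only = 1


def check_qkv_layout(input_type: AttentionInputType,
                     q: torch.Tensor,
                     k: Optional[torch.Tensor],
                     v: Optional[torch.Tensor]) -> None:
    """Context wants separate Q/K/V; generation wants a single fused tensor in q (illustrative)."""
    is_fused_qkv = k is None and v is None
    if input_type is AttentionInputType.context_only:
        assert not is_fused_qkv, "context-only MLA expects separate q, k, v tensors"
    elif input_type is AttentionInputType.generation_only:
        assert is_fused_qkv, "generation-only MLA expects fused QKV passed via q"


# Separate tensors (shapes are placeholders) satisfy the context-only check.
check_qkv_layout(AttentionInputType.context_only,
                 torch.empty(4, 1536), torch.empty(4, 576), torch.empty(4, 512))
```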
70-70: LGTM! Softmax statistics tensor properly added. The softmax_stats_tensor attribute is correctly declared with appropriate typing and follows the established pattern for optional tensor attributes in the wrapper class.
1272-1278: LGTM! Empty tensor creation properly handled for edge case. The empty tensor creation logic correctly handles the case when max_ctx_cached_token_len == 0:
- Proper dimensions using MLA parameters for both KV and K_PE tensors
- Maintains dtype and device consistency
- Returns appropriate tuple format matching the function signature
This ensures robust handling of edge cases in the chunked KV cache loading.
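A sketch of the early-out pattern being praised here, assuming placeholder MLA head dimensions; the function name and shapes are illustrative, not the real load_chunked_kv_cache_for_mla.

```python
from typing import Tuple

import torch


def load_chunked_kv_for_mla(max_ctx_cached_token_len: int,
                            kv_lora_rank: int = 512,
                            qk_rope_head_dim: int = 64,
                            dtype: torch.dtype = torch.bfloat16,
                            device: str = "cpu") -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (compressed_kv, k_pe); fall back to empty tensors when nothing is cached."""
    if max_ctx_cached_token_len == 0:
        compressed_kv = torch.empty(0, 1, kv_lora_rank, dtype=dtype, device=device)
        k_pe = torch.empty(0, 1, qk_rope_head_dim, dtype=dtype, device=device)
        return compressed_kv, k_pe
    raise NotImplementedError("only the empty-cache edge case is sketched here")


kv, k_pe = load_chunked_kv_for_mla(0)
print(kv.shape, k_pe.shape)  # torch.Size([0, 1, 512]) torch.Size([0, 1, 64])
```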
cpp/tensorrt_llm/thop/attentionOp.cpp (4)
66-83: LGTM! Well-structured function signature updates for separate QKV support. The signature changes appropriately introduce optional k and v tensors alongside the renamed qkv_or_q parameter, and the addition of the total_kv_len parameter aligns with the PR objectives. The comment clearly explains the rationale for separate QKV inputs in context MLA.
162-186: Excellent validation logic for MLA context processing. The comprehensive tensor validation ensures proper input handling for separate QKV tensors:
- Appropriate checks for tensor existence, dimensions, and memory layout
- Correct handling of latent_cache for different MLA scenarios
- Proper extraction of k_ptr and v_ptr for downstream processing
The validation follows defensive programming practices and maintains code robustness.
137-139: Consistent and clear variable renaming throughout the implementation. The renaming from qkv to qkv_or_q accurately reflects the dual usage pattern and improves code clarity. The addition of the total_kv_len parameter properly supports the new functionality requirements.
Also applies to: 290-290, 445-445, 457-457, 614-614, 653-653
412-415: Proper Torch library binding updates. The binding definition correctly adds the total_kv_lens parameter and maintains consistency with the C++ function signature. The parameter placement and naming follow the established patterns.
Also applies to: 754-754
/bot run --disable-fail-fast
PR_Github #13801 [ run ] triggered by Bot
PR_Github #13801 [ run ] completed with state
755c03f to 89cf5e1
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
70-70: Consider using Optional typing for consistency. The softmax_stats_tensor attribute should use Optional[torch.Tensor] typing to be consistent with other optional tensor attributes in the class.
- softmax_stats_tensor: Optional[torch.Tensor]
+ softmax_stats_tensor: Optional[torch.Tensor]
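For reference, the Optional-attribute convention this nitpick points at looks like the following in a plain dataclass; the class and field names below are illustrative, not the actual wrapper.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _PlanBuffers:
    """Illustrative only: tensors that may be absent are typed Optional and default to None."""
    output: Optional[torch.Tensor] = None
    softmax_stats_tensor: Optional[torch.Tensor] = None
    total_kv_lens: Optional[torch.Tensor] = None


buffers = _PlanBuffers(total_kv_lens=torch.zeros(2, dtype=torch.int64))
assert buffers.softmax_stats_tensor is None
```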
201-201: Fix line length violation. This line exceeds the 120-character limit specified in the coding guidelines.
The line should be wrapped or shortened to comply with the coding standards. Consider breaking the parameter documentation across multiple lines or shortening the description.
cpp/tensorrt_llm/thop/attentionOp.cpp (1)
191-191: Remove unnecessary blank line. There's an extra blank line that should be removed for consistency with the codebase formatting.
  mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
-
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
- cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
- tensorrt_llm/_torch/metadata.py (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
- tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
🚧 Files skipped from review as they are similar to previous changes (11)
- tensorrt_llm/_torch/metadata.py
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- tensorrt_llm/_torch/modules/attention.py
- cpp/tensorrt_llm/kernels/mlaKernels.h
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
- cpp/tensorrt_llm/common/attentionOp.cpp
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
cpp/tensorrt_llm/kernels/mlaKernels.cu
tensorrt_llm/_torch/attention_backend/trtllm.py
cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-...
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (3)
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : all tensorrt-llm open source software code should contain...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.550Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : all class templates, function templates, class template member ...
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.550Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : All class templates, function templates, class template member functions and class template static members shall be instantiated at least once.
Applied to files:
cpp/tensorrt_llm/kernels/mlaKernels.cu
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (31)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)
1-15: LGTM! Copyright header is compliant.The copyright header correctly includes the current year (2025) and follows the required NVIDIA format for TensorRT-LLM Open Source Software.
353-353: LGTM! Template instantiation is correct.The template instantiation for
invokeMLALoadChunkedKVwith FP8 cache type is properly formatted and syntactically correct.cpp/tensorrt_llm/kernels/mlaKernels.cu (6)
2-2: LGTM! Copyright header updated correctly. The copyright header properly includes the current year (2025) in compliance with TensorRT-LLM coding guidelines.
187-187: LGTM! Kernel signature updated for separate Q/K buffers. The change from a single qkv_output pointer to separate q_ptr and k_ptr parameters aligns with the PR objective to use separate QKV input layout, eliminating concatenation operations.
242-243: LGTM! Q buffer offset calculation is correct. The offset calculation properly computes the global token index multiplied by the head-specific stride, maintaining correct indexing for the separate Q buffer layout.
245-245: LGTM! Q buffer access is consistent. The buffer access correctly uses the separate q_ptr with the computed offset, maintaining consistency with the new separate buffer layout.
276-282: LGTM! Output assignments correctly use separate Q/K buffers. The index calculations and buffer writes properly handle the separate Q and K outputs, completing the transformation from fused QKV to separate buffer handling. The indexing is mathematically consistent and follows the expected memory layout.
826-826: LGTM! Function call updated to match new kernel signature. The kernel launch correctly passes params.k_buf as the separate K buffer argument, maintaining consistency with the updated kernel signature that now accepts separate Q and K pointers.
tensorrt_llm/_torch/attention_backend/trtllm.py (10)
26-26: LGTM! The addition of the total_kv_lens tensor attribute is correctly integrated into the dataclass definition and follows the established pattern of other tensor attributes.
157-157: LGTM! The parameter addition is correctly documented and follows the established parameter naming conventions.
228-228: LGTM! The assignment of the total_kv_lens parameter is correctly implemented and maintains consistency with other tensor assignments in the method.
418-418: LGTM! The total_kv_lens parameter is correctly passed to the attention operation call, maintaining consistency with the method signature updates.
609-609: LGTM! The initialization of the total_kv_lens tensor with proper shape and device placement is correctly implemented.
733-736: LGTM! The computation of total KV lengths for context and generation requests is correctly implemented. The logic properly sums KV lengths without extra tokens and stores them in the appropriate tensor indices.
1137-1137: LGTM! The total_kv_lens parameter is correctly passed to the wrapper's plan method, maintaining consistency throughout the call chain.
1232-1232: LGTM! The return type annotations have been updated from single tensor to tuple of two tensors, which aligns with the changes mentioned in the AI summary about returning separate K and V buffers.
Also applies to: 1274-1274
1280-1286: LGTM! The handling of the empty KV cache case is correctly implemented with proper tensor creation for both compressed KV and K position encoding tensors.
357-367: MLA QKV Handling Assertions Are Correct
Searches confirm that in tensorrt_llm/_torch/modules/attention.py and corresponding unit tests, AttentionInputType.context_only always invokes the multi-head attention (MHA) with separate q, k, v tensors, whereas generation_only uses a fused q tensor. The new assertions in trtllm.py simply enforce this existing behavior and align with the rest of the codebase. No changes required.
cpp/tensorrt_llm/thop/attentionOp.cpp (13)
66-83: LGTM! The interface updates to support separate Q, K, V tensors and the new total_kv_len parameter are well-designed. The comment on line 66 clearly explains the context for using separate QKV inputs.
122-135: LGTM! The Runner template class method signature correctly matches the base class interface with the new parameters for separate K/V tensors and total KV length.
137-141: LGTM! The variable renaming from qkv to qkv_or_q better reflects the dual usage pattern, and the initialization of separate K/V pointers is correctly implemented.
290-290: LGTM! The assignment of total_kv_len to the common enqueue parameters is correctly implemented and maintains the parameter flow through the attention pipeline.
307-308: LGTM! The assignment of separate K and V pointers to the enqueue parameters for the context stage is correctly implemented, supporting the new separate tensor input layout.
412-418: LGTM! The function signature update to include separate K/V tensors and the total_kv_lens parameter is correctly implemented and maintains consistency with the interface changes.
457-457: LGTM! The variable rename from qkv to qkv_or_q is correctly applied throughout the function, maintaining consistency with the interface changes.
613-618: LGTM! The extraction of total KV lengths for context and generation requests from the tensor is correctly implemented, providing the necessary values for separate processing paths.
652-652: LGTM! The workspace tensor creation uses the renamed qkv_or_q variable consistently with other changes in the function.
662-667: LGTM! The runner calls for both context and generation stages correctly pass the new parameters including separate K/V tensors and individual total KV lengths, maintaining the proper parameter flow through the attention pipeline.
Also applies to: 678-683
753-753: LGTM! The torch library binding signature is correctly updated to include the total_kv_lens parameter, maintaining consistency with the C++ interface changes.
162-186: Double-check MLA Context Tensor Validation Consistency
Please confirm that the new C++ checks match the tensor shapes and memory layouts produced by the Python LLM API:
- latent_cache
  • May be null for KV-cache reuse or chunked contexts – ensure the Python layer omits or passes None as expected.
- k and v tensors
  • Must be 2D (dim() == 2) with inner stride == 1 – verify that Python constructs these as contiguous [num_tokens, head_dim] buffers in context mode.
  • Slicing by token_offset (slice(0, token_offset)) should align with how tokens are buffered upstream.
If there's any mismatch between these constraints and the Python-side behavior (e.g. different stride layout, batched vs. flat dimensions), update either the C++ checks or the Python wrapper to ensure consistency.
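On the Python side, a quick self-check along the lines below would confirm the layout these C++ checks expect; the function and the dummy shapes are assumptions for illustration, not the wrapper's real code.

```python
import torch


def check_context_kv_layout(k: torch.Tensor, v: torch.Tensor, token_offset: int) -> None:
    """Verify K/V match the layout the context-MLA checks assume (illustrative only)."""
    for name, t in (("k", k), ("v", v)):
        assert t.dim() == 2, f"{name} must be [num_tokens, head_dim]"
        assert t.stride(-1) == 1, f"{name} inner dimension must be contiguous"
    # Dropping the first token_offset rows (like slice(0, token_offset) in ATen)
    # keeps the inner stride contiguous.
    assert k[token_offset:].stride(-1) == 1 and v[token_offset:].stride(-1) == 1


# Dummy contiguous buffers with placeholder head dims.
check_context_kv_layout(torch.randn(16, 576), torch.randn(16, 512), token_offset=0)
```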
443-445: Clarify MLA vs fused QKV assertion semantics
The current check in cpp/tensorrt_llm/thop/attentionOp.cpp (around lines 443–445) is written as an inclusive OR:
    TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
    TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
    auto qkv_or_q = q;
This allows both is_mla_enable and is_fused_qkv to be true simultaneously. If the intention was to make these modes mutually exclusive (i.e., MLA excludes fused QKV and vice versa), the assertion should use XOR logic:
- TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv,
-     "Only fused QKV is supported for non-MLA attention now");
+ TLLM_CHECK_WITH_INFO(is_mla_enable != is_fused_qkv,
+     "MLA attention requires separate Q/K/V; non-MLA requires fused QKV");
Otherwise, if dual-mode operation (MLA + fused QKV) is acceptable, update the comment string to clarify that both flags may be set together.
• Location: cpp/tensorrt_llm/thop/attentionOp.cpp:443–445
• No other MLA vs. fused-QKV checks were found elsewhere in the codebase.
• The update_kv_cache assertion is only present here and currently cannot be disabled.
Likely an incorrect or invalid review comment.
89cf5e1 to ab6a170
Actionable comments posted: 0
🔭 Outside diff range comments (1)
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)
2-2: Update copyright header to include current year. The copyright header shows 2022-2024, but the coding guidelines require including the current year (2025).
Apply this diff to fix the copyright header:
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
157-157: Fix line length violation in docstring. The addition of the total_kv_lens parameter is correct, but the docstring on line 201 exceeds the 120 character limit.
Apply this diff to fix the line length:
- total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
+ total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context
+     requests and generation requests, with shape (2) on CPU.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- cpp/tensorrt_llm/common/attentionOp.cpp (7 hunks)
- cpp/tensorrt_llm/common/attentionOp.h (3 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h (4 hunks)
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (5 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (0 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (5 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1 hunks)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (14 hunks)
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (0 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
- cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (12 hunks)
- tensorrt_llm/_torch/metadata.py (1 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (0 hunks)
- tensorrt_llm/_torch/modules/attention.py (5 hunks)
💤 Files with no reviewable changes (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🚧 Files skipped from review as they are similar to previous changes (11)
- tensorrt_llm/_torch/metadata.py
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h
- cpp/tensorrt_llm/common/attentionOp.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
- cpp/tensorrt_llm/kernels/mlaKernels.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h
- tensorrt_llm/_torch/modules/attention.py
- cpp/tensorrt_llm/common/attentionOp.cpp
- cpp/tensorrt_llm/kernels/mlaKernels.cu
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
tensorrt_llm/_torch/attention_backend/trtllm.py
cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-...
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (1)
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
tensorrt_llm/_torch/attention_backend/trtllm.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/attention_backend/trtllm.py
201-201: Line too long (148 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (26)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)
1-27: LGTM - Copyright header and includes are correct. The copyright header properly includes 2025 as required by the coding guidelines, and the includes are appropriate for this CUDA kernel file.
343-357: LGTM - Template instantiation macro is correctly structured. The macro properly instantiates the template functions for different data types and cache types. The modification on line 353 appears to be formatting-related and maintains the correct structure.
cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu (1)
93-507: LGTM - Test class structure is well-organized. The test class properly structures the MLA preprocessing tests. The removal of paged KV cache related functionality aligns with the PR objectives of simplifying MLA context handling. The remaining test infrastructure for load operations appears comprehensive and well-designed.
tensorrt_llm/_torch/attention_backend/trtllm.py (8)
26-26: LGTM - Addition of total_kv_lens attribute is well-integrated. The new total_kv_lens tensor attribute provides explicit tracking of total KV lengths for context and generation requests, which aligns with the PR objectives. The attribute is properly typed and integrated into the class structure.
Also applies to: 228-228
70-70: LGTM - Softmax stats tensor attribute is properly positioned. The softmax_stats_tensor attribute is correctly typed and positioned within the dataclass structure.
357-367: LGTM - MLA input type handling correctly implements separate QKV inputs. The updated assertions correctly enforce separate Q, K, V inputs for context_only processing and fused QKV for generation_only, which aligns with the PR objectives of using separate QKV input layout for context MLA.
418-418: LGTM - Addition of total_kv_lens to attention operation call. The total_kv_lens parameter is correctly added to the attention operation call, providing explicit KV length tracking as required by the updated kernel interface.
609-609: LGTM - Proper initialization of total_kv_lens tensor. The total_kv_lens tensor is correctly initialized with shape (2) on CPU to store separate totals for context and generation requests.
734-736: LGTM - Correct computation of total KV lengths. The computation properly separates context and generation request totals, providing the kernels with accurate KV length information for each request type.
1232-1232: LGTM - Return type updates and empty tensor handling are correct. The explicit tuple return types for load_paged_kv_cache_for_mla and load_chunked_kv_cache_for_mla correctly reflect the separate KV and positional embedding outputs. The empty tensor handling for the edge case when max_ctx_cached_token_len is 0 is good defensive programming.
Also applies to: 1274-1274, 1280-1286
1137-1137: LGTM - Integration of total_kv_lens in wrapper.plan call. The total_kv_lens parameter is correctly passed to the wrapper's plan method, ensuring the attention operation has access to the explicit KV length information.
cpp/tensorrt_llm/thop/attentionOp.cpp (15)
66-83: LGTM - Clean interface extension for separate QKV inputs. The virtual method signature is properly updated to support separate K and V tensors alongside the existing QKV tensor, with clear documentation of the use case.
120-135: LGTM - Implementation signature correctly matches base class. The method signature properly implements the virtual base class interface with consistent parameter types and naming.
137-141: LGTM - Proper variable initialization for separate QKV support. The renaming to qkv_or_q and initialization of separate K/V pointers correctly reflects the new dual-purpose interface.
162-186: LGTM - Well-structured MLA context handling with proper validation. The implementation correctly handles separate K and V tensors for MLA context mode with appropriate validation checks for tensor dimensions, strides, and presence. The latent cache handling properly accounts for different scenarios (standard context vs. KV cache reuse/chunked context).
289-289: LGTM - Correct parameter assignment. The total_kv_len parameter is properly assigned to the enqueue parameters structure.
306-307: LGTM - Proper pointer assignment for context stage. The K and V pointers are correctly assigned to the enqueue parameters for the context processing stage.
411-435: LGTM - Function signature properly updated for new interface. The attention_inplace function signature is correctly updated to include the total_kv_lens parameter, maintaining consistency with the runner interface changes.
442-444: LGTM - Important validation checks for configuration consistency. The checks ensure that either MLA is enabled or fused QKV is used, and that KV cache update is always enabled. These validations prevent invalid attention configurations after the refactoring.
456-456: LGTM - Correct usage of renamed variable. The data type extraction correctly uses the renamed qkv_or_q tensor.
651-651: LGTM - Proper device inference for workspace allocation. The workspace tensor creation correctly uses the device from the qkv_or_q tensor.
612-612: LGTM - Correct token count calculation. The number of tokens is properly calculated from the first dimension of the input tensor.
615-616: LGTM - Proper extraction of total KV lengths. The context and generation total KV lengths are correctly extracted from the total_kv_lens tensor using appropriate indexing.
661-666: LGTM - Context runner call properly updated with new parameters. The runner call correctly passes the renamed qkv_or_q tensor, separate k and v tensors, and the context total KV length parameter.
677-682: LGTM - Generation runner call properly updated with new parameters. The runner call correctly passes the new parameters including the generation total KV length, maintaining consistency with the context call pattern.
752-752: LGTM - Torch library binding properly updated. The function binding correctly includes the new total_kv_lens parameter, maintaining consistency with the updated C++ function signature.
/bot run --disable-fail-fast
PR_Github #13920 [ run ] triggered by Bot
PR_Github #13920 [ run ] completed with state
ab6a170 to f8459b9
/bot kill
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
394cf82 to 1c74fe8
/bot run --disable-fail-fast
PR_Github #15632 [ run ] triggered by Bot
Actionable comments posted: 4
🔭 Outside diff range comments (6)
cpp/tensorrt_llm/common/attentionOp.cpp (2)
770-772: Under-allocation: include mFP8ContextMLA in BMM scale workspace sizes
fmha_bmm1_scale_size and fmha_bmm2_scale_size only account for mFP8ContextFMHA. FP8 context MLA also consumes these buffers (you pass bmm1/bmm2 scales into params.mla_param and use them in invokeMLAContextFp8Quantize). This can under-allocate the context workspace and corrupt subsequent allocations.
Apply this fix:
- size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
- size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+ size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
+ size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;
1655-1683: Null q/k/v pointers when latent_cache is absent break FP8 quantization
invokeMLAContextFp8Quantize() requires params.mla_param->q_buf/k_buf/v_buf to be set. In the default MLA context path (latent_cache == nullptr), these pointers are never initialized before the quant call, causing TLLM_CHECK failures or undefined behavior.
Set q/k/v pointers for the default path prior to quantization:
  params.mla_param->host_bmm1_scale
      = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
+ // Ensure q/k/v are set when latent_cache is absent (default separate Q/K/V input).
+ if (params.mla_param->q_buf == nullptr)
+ {
+     // attention_input already points to Q in context FMHA path.
+     params.mla_param->q_buf = const_cast<T*>(attention_input);
+ }
+ if (params.mla_param->k_buf == nullptr)
+ {
+     // k_ptr comes from THOP attention binding for separate K input.
+     params.mla_param->k_buf = const_cast<T*>(params.k_ptr);
+ }
+ if (params.mla_param->v_buf == nullptr)
+ {
+     // v_buf is const in MlaParams.
+     params.mla_param->v_buf = params.v_ptr;
+ }
  if (params.mla_param->latent_cache != nullptr)
  {
      // compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
      invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
  }
  if (mFP8ContextMLA)
  {
      invokeMLAContextFp8Quantize(*params.mla_param, params.total_kv_len, stream);
  }
tensorrt_llm/_torch/modules/attention.py (1)
1164-1172: Syntax error in type annotation for latent_cache
The argument annotation is split across lines, producing invalid Python syntax.
Apply this fix:
- latent_cache: torch. - Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size] + latent_cache: torch.Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]cpp/tensorrt_llm/thop/attentionOp.cpp (3)
140-147: Guard against non-contiguous/invalid Q layout before taking raw pointer
The enqueue path assumes the inner-most dimension is contiguous. Add light shape/stride checks to fail-fast when Q is not laid out as expected.
Apply this diff:
- auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
- T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
+ auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
+ TORCH_CHECK(qkv_or_q.dim() >= 2, "qkv_or_q must be at least 2D: [tokens, hidden]");
+ TORCH_CHECK(qkv_or_q.strides()[qkv_or_q.dim() - 1] == 1, "qkv_or_q inner-most dimension must be contiguous");
+ T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
652-655: Fix potential null deref when beam_width > 1 but cache_indirection is absent
This uses cache_indirection.value() without has_value() guard. Prefer the same safe pattern you already use inside Runner::run.
Apply this diff:
- int32_t const max_attention_window_size = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
+ int32_t const max_attention_window_size = beam_width == 1
+     ? attention_window_size
+     : (cache_indirection.has_value() ? cache_indirection.value().size(2) : attention_window_size);
576-601: Missing has_value() checks for MLA params (q_lora_rank/kv_lora_rank/...): will crash if unset
You call .value() on several optionals without verifying has_value(). Add checks and provide a clear error message.
Apply this diff:
  if (is_mla_enable)
  {
      // MLA does not support NVFP4 output yet.
      TLLM_CHECK(!is_fp4_out);
      TLLM_CHECK(host_kv_cache_pool_mapping.has_value());
      int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);
+     TLLM_CHECK(q_lora_rank.has_value() && kv_lora_rank.has_value() && qk_nope_head_dim.has_value()
+             && qk_rope_head_dim.has_value() && v_head_dim.has_value(),
+         "MLA requires q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, and v_head_dim");
+
      op->mIsMLAEnabled = true;
      op->mMLAParams = {static_cast<int>(q_lora_rank.value()), static_cast<int>(kv_lora_rank.value()),
          static_cast<int>(qk_nope_head_dim.value()), static_cast<int>(qk_rope_head_dim.value()),
          static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
          static_cast<int>(layer_num)};
♻️ Duplicate comments (13)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (8)
1850-1851: Document mask/layout fields for readability (avoid magic numbers)
The 0/1 (mask) and 3 (layout) columns correspond to AttentionMaskType and AttentionInputLayout respectively. Inline annotations improve clarity and reduce regressions.
Apply:
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},If this header is generated, consider emitting enum names (e.g., static_cast(AttentionInputLayout::SEPARATE_Q_K_V)) in the generator.
4333-4334: SM100 NL-tiled: annotate mask/layout and confirm shared wrapper + implementation
- Add inline comments for mask/layout.
- Both rows share run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled. Confirm this is intended for both mask kinds.
- Ensure the wrapper’s definition exists (see extern at Line 1083).
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
4743-4744: SM120 NL-tiled BF16: annotate mask/layout and ensure wrapper definition exists
- Add inline mask/layout comments.
- Confirm the wrapper run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled (extern at Line 1224) is implemented and linked.
- Validate that sharing the same wrapper for causal and non-causal is intentional.
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
212-216: SM90 separate Q-K-V wrappers: ensure implementations exist to avoid link errors
The new externs for SM90 separate Q-K-V context paths look correct, but prior verification found no corresponding definitions for these symbols, which will lead to linker failures if still missing:
- run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90 (Line 212)
- run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90 (Line 216)
Please add their definitions in the appropriate .cu/.cpp unit (consistent with other run_fmha_v2_* wrappers) and ensure runtime dispatch reaches them.
Run this to confirm definitions exist (expects at least one non-header definition per symbol):
#!/bin/bash set -euo pipefail syms=( run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90 run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90 ) for s in "${syms[@]}"; do echo ">>> Checking definition for $s" if ! rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\("; then echo "!! Missing definition for: $s" exit 1 fi done echo "All SM90 S_q_k_v wrappers found."
1083-1083: SM100 NL-tiled separate Q-K-V wrapper: definition likely missing
Extern added:
- run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled
A previous review found no implementation. Add the definition and ensure both non-causal and causal registry entries route correctly to this wrapper.
#!/bin/bash set -euo pipefail sym='run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled' rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${sym}[[:space:]]*\\(" || { echo "Missing definition for: $sym" exit 1 }
1224-1224: SM120 NL-tiled BF16 separate Q-K-V wrapper: add implementation
Extern added:
- run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled
This was previously flagged as missing. Provide the .cu/.cpp implementation and include it in the build.
#!/bin/bash set -euo pipefail sym='run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled' rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${sym}[[:space:]]*\\(" || { echo "Missing definition for: $sym" exit 1 }
1254-1258: SM120 E4M3 64×64 separate Q-K-V wrappers: add missing implementations
Externs added:
- run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled (Line 1254)
- run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled (Line 1258)
Previous checks found no definitions. Implement both wrappers and ensure they are linked.
#!/bin/bash set -euo pipefail syms=( run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled ) for s in "${syms[@]}"; do echo ">>> Checking definition for $s" rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\(" || { echo "Missing definition for: $s" exit 1 } done echo "All SM120 E4M3 S_q_k_v wrappers found."
4833-4834: SM120 E4M3 separate Q-K-V rows: annotate mask/layout and implement wrappers
- Add inline comments for mask/layout to avoid magic numbers.
- Ensure both E4M3 wrappers are implemented:
- run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled (Line 4833/4834)
- run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled (Line 4842/4843)
-{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled}, -{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},Quick definition check:
#!/bin/bash set -euo pipefail syms=( run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled ) missing=0 for s in "${syms[@]}"; do if ! rg -nP -C2 -g '!**/*.h' -g '*.{cu,cpp,cc}' "^[[:space:]]*(?:__host__\\s+|__global__\\s+)?void[[:space:]]+${s}[[:space:]]*\\(" ; then echo "Missing definition for: $s" missing=1 fi done exit $missingAlso applies to: 4842-4843
cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h (1)
569-576: Fix KV head-indexing for GQA/MQA in Gmem_tile_q_k_v
The ctor currently sets idx = binfo.bidh for all tensors. That is incorrect when numKvHeads < numHeads (GQA/MQA) for K/V: the K/V head index must be kv_head_id = bidh / headsPerKv. Mirror the mapping used elsewhere.
Apply this patch:
- // The row offset in the batched GEMM, including the sequence offset. - int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_; - // Add the head index. - int64_t idx = binfo.bidh; + // The row offset in the batched GEMM, including the sequence offset. + int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_; + // Add the head index (map Q head -> KV head for GQA/MQA). + int64_t idx; + if (qkv_offset == 0) + { + // Q tensor: one head per Q head. + idx = binfo.bidh; + } + else + { + // K or V tensor: map Q head id to KV head id. + int const num_heads = params.h; + int const num_kv_heads = max(1, params.h_kv); + int const heads_per_kv = max(1, num_heads / num_kv_heads); + int const kv_head_id = binfo.bidh / heads_per_kv; + idx = kv_head_id; + }cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
944-962: Quantize kernel launch: guard dims and grid before launch
The quantize kernel is instantiated for QK_NOPE=128, QK_ROPE=64, V=128. Add runtime checks to fail clearly if params meta dims differ, and validate grid dims against device limits (prior suggestion).
Apply this at the top of invokeMLAContextFp8Quantize():
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing separate qkv to FP8"); if (params.acc_q_len > 0) { constexpr int threads_per_block = 384; dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len, 48)), 1, params.head_num); + // Sanity check the expected head dims for this specialization. + TLLM_CHECK_WITH_INFO(params.meta.qk_nope_head_dim == 128 + && params.meta.qk_rope_head_dim == 64 + && params.meta.v_head_dim == 128, + "MLA FP8 quantization currently specialized for (qk_nope, qk_rope, v) = (128, 64, 128); got (%d, %d, %d).", + params.meta.qk_nope_head_dim, params.meta.qk_rope_head_dim, params.meta.v_head_dim); + // Validate grid dimensions to avoid exceeding device limits. + int maxGridDimX = 0, maxGridDimZ = 0; + cudaDeviceGetAttribute(&maxGridDimX, cudaDevAttrMaxGridDimX, 0); + cudaDeviceGetAttribute(&maxGridDimZ, cudaDevAttrMaxGridDimZ, 0); + TLLM_CHECK_WITH_INFO(grid.x <= maxGridDimX && grid.z <= maxGridDimZ, + "Quantize kernel grid dims exceed device limits: (%d,%d,%d) vs max (%d,*,%d).", + grid.x, grid.y, grid.z, maxGridDimX, maxGridDimZ);cpp/kernels/fmha_v2/fmha_test.py (2)
168-169: Include SM121 in FP8 context MLA gatingAllow FP8 MLA tests to run on SM121 as well. This avoids false skips on 12.1 parts.
Apply this diff:
- if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
-     pytest.skip("FP8 MLAs are only supported on sm120 currently.")
+ if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in (120, 121):
+     pytest.skip("FP8 MLAs are only supported on sm120/sm121 currently.")
213-214: Include SM121 in FP8 generation MLA gating
Mirror the context change to avoid false skips on 12.1.
Apply this diff:
- if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
-     pytest.skip("FP8 MLAs are only supported on sm120 currently.")
+ if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in (120, 121):
+     pytest.skip("FP8 MLAs are only supported on sm120/sm121 currently.")
cpp/tensorrt_llm/thop/attentionOp.cpp (1)
168-185: Also validate K/V dtype and device to match Q (carry-over from prior review)
The context-MLA path still lacks dtype/device checks for k and v. This can cause UB when kernels reinterpret memory.
Apply this diff to add the missing validations:
  TORCH_CHECK(k.has_value());
  TORCH_CHECK(v.has_value());
  TORCH_CHECK(k->dim() == 2);
  TORCH_CHECK(v->dim() == 2);
  TORCH_CHECK(k->strides()[1] == 1);
  TORCH_CHECK(v->strides()[1] == 1);
+ TORCH_CHECK(k->scalar_type() == qkv_or_q.scalar_type(), "K dtype must match Q dtype");
+ TORCH_CHECK(v->scalar_type() == qkv_or_q.scalar_type(), "V dtype must match Q dtype");
+ TORCH_CHECK(k->scalar_type() == v->scalar_type(), "K and V dtypes must match");
+ TORCH_CHECK(k->device() == qkv_or_q.device() && v->device() == qkv_or_q.device(),
+     "K/V must be on the same device as Q");
🧹 Nitpick comments (24)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (1)
1860-1861: SM90 softmax rows: annotate mask/layout and confirm shared wrapper is intended
- Add inline comments for mask/layout to avoid magic numbers.
- Both non-causal and causal rows reference the same run wrapper. Verify the wrapper branches on mAttentionMaskType correctly for CAUSAL vs PADDING.
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=PADDING */ 0, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, /* mask=CAUSAL */ 1, /* layout=SEPARATE_Q_K_V */ 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (5)
149-161: Remove duplicate early-return check for merge_op.
merge_op_val == 0 is checked twice back-to-back. The second check is redundant and can be removed to reduce instruction count and warp divergence.
Apply this diff:
merge_op_val == 0is checked twice back-to-back. The second check is redundant and can be removed to reduce instruction count and warp divergence.Apply this diff:
int64_t merge_op_val = merge_op[batch_idx]; if (merge_op_val == 0) { return; // skip this batch } - size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead); - size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerThread; - - if (merge_op_val == 0) - { - return; // skip this batch - } + size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead); + size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerThread;
162-185: Use 64-bit/size_t for global offsets to avoid potential overflow on long sequences.Offsets derived from cu_q_seq_len (int64_t) are narrowed to int and used in pointer indexing. For large concatenated sequences, this can overflow 32-bit. Prefer int64_t/size_t for offsets.
If your workloads guarantee <2^31 elements, this is low risk; otherwise, consider the change below.
- int const curr_q_len = static_cast<int>(cu_q_seq_len[batch_idx + 1] - cu_q_seq_len[batch_idx]); - int const global_q_offset = cu_q_seq_len[batch_idx]; + int64_t const curr_q_len = cu_q_seq_len[batch_idx + 1] - cu_q_seq_len[batch_idx]; + int64_t const global_q_offset = cu_q_seq_len[batch_idx]; for (int local_token_idx = (threadIdx.x / KT::kVecPerHead) + blockIdx.x * KT::kTokenPerBlock; local_token_idx < curr_q_len; local_token_idx += gridDim.x * KT::kTokenPerBlock) { // load softmax stat - int const global_softmax_stats_offset = (global_q_offset + local_token_idx) * num_heads + head_idx; + size_t const global_softmax_stats_offset + = static_cast<size_t>(global_q_offset + local_token_idx) * static_cast<size_t>(num_heads) + + static_cast<size_t>(head_idx); float2 curr_stats = curr_softmax_stats[global_softmax_stats_offset]; float2 pre_stats = pre_softmax_stats[global_softmax_stats_offset]; // load attn typename KT::VecReader pre_attn_reader{}; typename KT::VecReader curr_attn_reader{}; typename KT::VecReader merged_attn_reader{}; - int const global_attn_offset - = (global_q_offset + local_token_idx) * num_heads * head_size + head_idx * head_size; + size_t const global_attn_offset = static_cast<size_t>(global_q_offset + local_token_idx) + * static_cast<size_t>(num_heads) * static_cast<size_t>(head_size) + + static_cast<size_t>(head_idx) * static_cast<size_t>(head_size);
201-203: Use device fast-math for exp to avoid double promotion and improve throughput.In device code,
std::expmay promote to double and is slower. Use__expf(orexpf) with float inputs for better performance and consistency withfmaxf.- float pre_shift = std::exp(pre_stats.x - merged_stats.x); - float curr_shift = std::exp(curr_stats.x - merged_stats.x); + float pre_shift = __expf(pre_stats.x - merged_stats.x); + float curr_shift = __expf(curr_stats.x - merged_stats.x);Note: If you rely on precise IEEE semantics, prefer
expf;__expftrades a bit of precision for speed.
246-261: Type kvSrc as TCache and make it const for correctness and clarity.
The cache storage element type is TCache, which can differ from T (e.g., FP8). Casting to T* is confusing and may mislead future refactors. Use TCache const* and keep the vectorized load const-correct.
- auto* kvSrc = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
+ auto const* kvSrc
+     = reinterpret_cast<TCache const*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
  // head_idx === 0
  auto kvBlockIdx = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast<int>(head_dim_vec_idx));
- auto ld_data = (reinterpret_cast<typename KT::VecT*>(kvSrc))[kvBlockIdx];
+ auto ld_data = (reinterpret_cast<typename KT::VecT const*>(kvSrc))[kvBlockIdx];
138-221: Minor: consider aligning math intrinsics and documenting head_size assumptions.
- You use fmaxf and std::exp together; after switching to __expf/expf, math intrinsics will be consistent.
- Head size is hardcoded in traits (128) and checked at runtime; a brief comment in the kernel noting that KT::kHeadSize must be 128 would help future maintainers.
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp (1)
1-4: Header-check exception for LFS pointer files
This .cpp is a Git LFS pointer, so adding a copyright header would break the pointer format. Ensure the header/lint rules exclude cubin pointer files under cubin/ to avoid false positives in CI.
If needed, I can propose an exclude pattern update for your header-check tooling (e.g., regex to skip paths matching
**/cubin/*.cubin.cpp).
cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp (1)
155-158: SEPARATE_Q_K_V runtime path: ensure q/k/v pointers are present
When using the separate Q/K/V layout, qPtr/kPtr/vPtr must be non-null. Add a defensive check to fail fast if callers forget to set them.
Apply this diff locally within FmhaDispatcher::run:
- else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
+ else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
  {
      qkvLayout = kernels::QkvLayout::SeparateQkv;
+     TLLM_CHECK_WITH_INFO(runnerParams.qPtr && runnerParams.kPtr && runnerParams.vPtr,
+         "SEPARATE_Q_K_V layout requires non-null qPtr/kPtr/vPtr.");
  }
cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h (1)
634-640: Preserve move_col(steps) in all Q/K/V tile variants
To keep the API symmetric with rewind_col(int steps) and consistent across sibling tiles, add a default steps parameter to every move_col that currently lacks it.
cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h:
- Gmem_tile_q_k_v::move_col (around line 635)
- Gmem_tile_contiguous_kv::move_col (around line 1190)
Proposed patch:
--- a/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h +++ b/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h @@ -633,7 +633,9 @@ struct Gmem_tile_q_k_v { - inline __device__ void move_col() + inline __device__ void move_col(int const steps = 1) { - q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8); - // Update col_in_bytes_ to ensure load predicates work - col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; + q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; } @@ -1188,7 +1188,9 @@ struct Gmem_tile_contiguous_kv { - inline __device__ void move_col() + inline __device__ void move_col(int const steps = 1) { - kv_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8); - // Update col_in_bytes_ to ensure load predicates work - col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG; + kv_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8) * steps; + // Update col_in_bytes_ to ensure load predicates work + col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG * steps; }tensorrt_llm/_torch/attention_backend/trtllm.py (1)
205-206: Docstring line exceeds 120 chars
Break the host_total_kv_lens docstring line to satisfy the style limit flagged by static analysis.
- host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU. + host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for + context requests and generation requests, with shape (2) on CPU.cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
820-921: QuantizeCopyInputToFp8Kernel: consider generalizing grid.x token groupinggrid.x uses divUp(total_kv_len, 48). 48 is the LCM of tokens-per-block for QK and V for BF16/FP16, but not for FP32; it still works but may underutilize SMs on some shapes. Consider computing grid.x based on the larger of QK_TOKENS_PER_BLOCK and V_TOKENS_PER_BLOCK (or their LCM) per specialization to balance occupancy and iteration count.
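As a rough illustration of that grid-sizing suggestion, here is a minimal C++ sketch that derives the token group from the two tile constants at compile time instead of hard-coding 48; the constant values and helper names below are placeholders, not the actual kernel traits:

#include <cstdint>

// Placeholder values for illustration only; the real constants live in the kernel traits.
constexpr int QK_TOKENS_PER_BLOCK = 48;
constexpr int V_TOKENS_PER_BLOCK = 16;

constexpr int gcd_int(int a, int b) { return b == 0 ? a : gcd_int(b, a % b); }
constexpr int lcm_int(int a, int b) { return a / gcd_int(a, b) * b; }
constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

// Token group shared by the QK and V copy paths; the formula stays valid
// if a specialization changes either constant.
constexpr int kTokenGroup = lcm_int(QK_TOKENS_PER_BLOCK, V_TOKENS_PER_BLOCK);

inline int gridDimXFor(int total_kv_len)
{
    // grid.x derived per specialization rather than from a fixed literal.
    return div_up(total_kv_len, kTokenGroup);
}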
tests/unittest/_torch/test_attention_mla.py (2)
671-673: Broaden enable_flash_mla gating beyond (9, 0)Flash MLA appears to be available on newer arch as well. Consider enabling on SM100/SM120 too to avoid false negatives on capable hardware.
- enable_flash_mla=torch.cuda.get_device_capability() == (9, 0), + enable_flash_mla=torch.cuda.get_device_capability() in {(9, 0), (10, 0), (12, 0)},
750-757: Line length in diagnostic prints exceeds 120Split long f-strings for readability and to satisfy linters.
- print( - f"{backend_name} output mean: {result.abs().mean().item()}, max: {result.abs().max().item()}" - ) - print( - f"Reference output mean: {ref_result.abs().mean().item()}, max: {ref_result.abs().max().item()}" - ) - print( - f"Difference mean: {(result - ref_result).abs().mean().item()}, max: {(result - ref_result).abs().max().item()}" - ) + print(f"{backend_name} output mean: {result.abs().mean().item()}, " + f"max: {result.abs().max().item()}") + print(f"Reference output mean: {ref_result.abs().mean().item()}, " + f"max: {ref_result.abs().max().item()}") + print(f"Difference mean: {(result - ref_result).abs().mean().item()}, " + f"max: {(result - ref_result).abs().max().item()}")cpp/kernels/fmha_v2/setup.py (2)
3649-3660: Clarify gating intent for MLA vs normal return-softmax casesThe new skip_mla_combination fixes earlier over-enumeration. Minor readability tweak: make MLA layout intent explicit and avoid double-negatives.
Apply this diff:
- # for normal attention, we only need contiguous kv as input layout when returning softmax. - skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV - # for context mla, we need separate qkv as input layout when returning softmax. - skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V + # For normal attention w/ return_softmax, only CONTIGUOUS_Q_KV is needed. + skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV + # For context MLA w/ return_softmax, only SEPARATE_Q_K_V is meaningful. + is_mla_layout = input_layout == InputLayout.SEPARATE_Q_K_V + skip_mla_combination = return_softmax and not is_mla_layout
6367-6391: Fix E712 comparisons to True/False in kernel filter (ruff)Pythonic conditionals improve readability and satisfy lints. Replace explicit True/False comparisons.
Apply this diff:
- # Deepseek MLA (generation 576/512 paged) - or (kspec.sm in [90, 100, 120] + # Deepseek MLA (generation 576/512 paged) + or (kspec.sm in [90, 100, 120] and kspec.dtype in ['bf16', 'e4m3_fp32'] and kspec.head_size == 576 and kspec.head_size_v == 512 and kspec.input_layout == InputLayout.Q_PAGED_KV and kspec.sage_block_sizes is None and kspec.version == 2 - and kspec.cross_mha == False - and kspec.flash_attention == True - and kspec.warp_specialization == False - and kspec.tiled == True) + and not kspec.cross_mha + and kspec.flash_attention + and not kspec.warp_specialization + and kspec.tiled) @@ - # Deepseek MLA (context 192/128 separate-q-k-v) - or (kspec.sm in [90, 100, 120] - and kspec.dtype in ['bf16', 'e4m3_fp32'] + # Deepseek MLA (context 192/128 separate-q-k-v) + or (kspec.sm in [90, 100, 120] + and kspec.dtype in ['bf16', 'e4m3_fp32'] and kspec.head_size == 192 and kspec.head_size_v == 128 and kspec.input_layout == InputLayout.SEPARATE_Q_K_V and kspec.sage_block_sizes is None and kspec.version == 2 - and kspec.cross_mha == False - and kspec.flash_attention == True - and ((kspec.warp_specialization == True and kspec.alibi == False) # sm90 - or (kspec.warp_specialization == False and kspec.tiled == True)) # non-sm90 - and kspec.enable_attn_logit_softcapping == False) + and not kspec.cross_mha + and kspec.flash_attention + and ((kspec.warp_specialization and not kspec.alibi) # sm90 + or (not kspec.warp_specialization and kspec.tiled)) # non-sm90 + and not kspec.enable_attn_logit_softcapping)cpp/tensorrt_llm/common/attentionOp.cpp (1)
2628-2639: Using SEPARATE_Q_K_V and headSizeQkNope for context MLA is appropriate
Forcing AttentionInputLayout::SEPARATE_Q_K_V for context MLA and plumbing headSizeQkNope through matches the kernel-param needs introduced in the fmha runners. The headSizeV note is documented; consider an assertion or static check in debug builds in case future models diverge.
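A contextual sketch of such a guard, reusing the TLLM_CHECK_WITH_INFO macro seen elsewhere in this PR; the member names (mMLAParams.v_head_dim, headSizeV) are assumptions for illustration, not the verified fields:

// Hypothetical debug-time guard; adjust names to the actual members.
TLLM_CHECK_WITH_INFO(headSizeV == static_cast<int>(mMLAParams.v_head_dim),
    "Context MLA expects headSizeV (%d) to equal v_head_dim (%d); revisit the FMHA setup if the model diverges.",
    headSizeV, static_cast<int>(mMLAParams.v_head_dim));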
cpp/tensorrt_llm/kernels/mlaKernels.h (1)
54-104: q/k mutability justified; consider typing FP8 quant buffers
You made q_buf and k_buf mutable because RoPE modifies them in-place; v_buf is const, which is correct. Consider typing the FP8 quant buffers as __nv_fp8_e4m3* instead of void* to avoid repeated casts and improve readability across declarations and call sites.
If changing the header type is too intrusive now, at least add brief comments in callers where casts happen to reduce future confusion.
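A minimal sketch of the typed alternative, under the assumption of hypothetical member names (the real struct in mlaKernels.h may look different):

#include <cuda_fp8.h>

// Illustration only: typed FP8 quantization buffers remove the need for
// reinterpret_cast at every call site and document the element type.
struct MlaFp8QuantBuffersSketch
{
    __nv_fp8_e4m3* quant_q = nullptr;
    __nv_fp8_e4m3* quant_k = nullptr;
    __nv_fp8_e4m3* quant_v = nullptr;
};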
cpp/tensorrt_llm/thop/attentionOp.h (1)
37-61: Minor consistency: prefer one optional type
This header mixes std::optional<torch::Tensor> and torch::optional<torch::Tensor>. Both work, but pick one style for consistency (project convention leans towards std::optional in C++17 code).
cpp/tensorrt_llm/thop/attentionOp.cpp (6)
67-85: Run() signature explosion: consider bundling args to improve safety and readability
The expanded run signature is getting unwieldy and error-prone. Consider introducing a small POD/struct (e.g., RunArgs) to group the many tensors and scalars. This reduces callsite mistakes, eases evolution, and matches the pattern you already use for enqueue params.
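A minimal sketch of what such a bundle could look like; the field selection is illustrative only and is not the actual run() signature:

#include <cstdint>
#include <optional>
#include <torch/extension.h>

// Hypothetical argument bundle; fields shown are a subset chosen for illustration.
struct RunArgsSketch
{
    torch::Tensor output;
    torch::Tensor qkv_or_q;
    std::optional<torch::Tensor> k;
    std::optional<torch::Tensor> v;
    std::optional<torch::Tensor> workspace;
    int64_t total_kv_len = 0;
    bool is_context = false;
};

// Callers populate named fields, which is harder to get wrong than a long positional list:
// RunArgsSketch args; args.qkv_or_q = q; args.k = k; args.v = v; runner->run(args);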
668-681: Workspace device selection is correct, but add capacity check when external buffer is provided
If the provided workspace tensor's device mismatches q's device, or its dtype is not byte, implicit assumptions could break. Consider validating dtype and device.
Apply this diff:
  if (workspace_.has_value())
  {
      if (workspace_.value().numel() < workspace_size)
      {
          TLLM_LOG_WARNING("Attention workspace size is not enough, increase the size from %ld bytes to %ld bytes",
              workspace_.value().numel(), workspace_size);
          workspace_.value().resize_({workspace_size});
      }
-     workspace = workspace_.value();
+     TORCH_CHECK(workspace_.value().dtype() == torch::kByte,
+         "Workspace tensor must have dtype=Byte");
+     TORCH_CHECK(workspace_.value().device() == qkv_or_q.device(),
+         "Workspace tensor must be on same device as inputs");
+     workspace = workspace_.value();
  }
  else
  {
-     workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv_or_q.device()));
+     workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv_or_q.device()));
  }
469-487: Top-level guards on fused vs separate inputs make sense, but add explicit K/V dtype/device checks early
You validate presence of K/V when needed. For earlier/faster failure, mirror the dtype/device checks here (in addition to Runner::run), so errors are caught before dispatch/templating.
Would you like me to draft a minimal guard block here?
285-292: Validate softmax_stats_tensor dtype/shape before casting to float2*
You correctly enforce float32 dtype, but since it's reinterpreted as float2, also assert that numel is even and the inner stride aligns to sizeof(float).
Proposed snippet (contextual):
TORCH_CHECK(softmax_stats_tensor->numel() % 2 == 0, "softmax_stats_tensor must have even number of floats for float2");
1-16: Update copyright year
Per coding guidelines, use the current year. Update 1993-2024 to 1993-2025.
Apply this diff:
- * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
683-712: Request: add targeted tests for separate-QKV context MLA across all paths
Given the behavioral changes (default/KV-reuse/chunked), please add tests that:
- validate K/V dtype/device mismatch raises,
- verify total_kv_len is honored in enqueue,
- cover both SM90 and SM120 config gates (feature flags permitting).
I can scaffold unit tests for test_attention_mla.py covering these scenarios. Want me to draft them?
...HeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp
PR_Github #15632 [ run ] completed with state |
Description
Use separate qkv input layout for context MLA, including default path, kv cache reuse path and chunked context path.
We can eliminate several extra operations: concat in the default path, set_paged_kv in the kv cache reuse path, and set_chunked_kv in the chunked context path. We can also eliminate the copy of tensor V by setting the correct stride for it.
This PR also adds fmha_v2 kernels with separate qkv support for sm120. After this PR, the support matrix of DeepSeek V3/R1 is:
Benchmark on 8*B200 with nvidia/DeepSeek-R1-0528-FP4, FP8 KV cache enabled, ISL=1K, OSL=2K, num-requests=114688 (following this doc).
The speedup is trivial, because this PR focuses only on the context phase, while the generation phase dominates this test case and the max_num_tokens is only 2048.
P.S. After modifying OSL to 1 and num-requests to 49152, we got a 3.4% throughput improvement (76440.4384 vs 79022.3162).
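To make the separate-QKV idea from the description concrete, here is a conceptual C++ sketch of the kind of view the context FMHA consumes; the struct name and fields are illustrative assumptions, not the actual runner parameters:

#include <cstdint>

// Illustration only: instead of materializing a packed [Q;K;V] buffer, the context
// FMHA receives three pointers plus per-token strides, so K and V are read from
// where they already live and V needs no compaction copy.
struct SeparateQkvViewSketch
{
    void const* q = nullptr;  // [total_q_tokens, num_heads, head_dim_qk]
    void const* k = nullptr;  // [total_kv_tokens, num_kv_heads, head_dim_qk]
    void const* v = nullptr;  // [total_kv_tokens, num_kv_heads, head_dim_v]
    int64_t q_stride_bytes = 0;   // distance between consecutive Q tokens
    int64_t kv_stride_bytes = 0;  // distance between consecutive K/V tokens; a stride larger
                                  // than the row size lets V be a strided view of an existing buffer
};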
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see
docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Refactor
Bug Fixes
Chores / Removals
Performance / FP8