Fix max_seq_len_k in trtllm_mha attention backend #9416
Conversation
Summary of Changes
Hello @Qiaolin-Yu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses an issue in the trtllm_mha attention backend where the max_seq_len_k parameter was being inconsistently handled. The fix ensures that max_seq_len_k within the forward metadata accurately reflects the maximum sequence length of the current batch, while the underlying attention kernel consistently receives the global max_context_len for proper operation. This correction aims to improve the stability and correctness of the attention mechanism, particularly when using CUDA graphs.
Highlights
- Corrected max_seq_len_k calculation: The max_seq_len_k attribute within the forward metadata is now dynamically calculated based on the maximum sequence length present in the current batch, rather than using a global maximum context length (see the sketch after this list).
- Consistent use of max_context_len in attention kernel calls: The max_seq_len and max_kv_len parameters passed to the core attention kernel (_attn_fwd) in both the forward_decode and forward_extend methods now consistently use self.max_context_len, ensuring the kernel operates with the correct overall context capacity.
- Enhanced stability and correctness: These changes resolve an inconsistency in how sequence length parameters were handled, improving stability and correctness for the trtllm_mha attention backend, especially when leveraging CUDA graphs.
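For context, here is a minimal sketch of the split described above. The names are hypothetical stand-ins (seq_lens for the per-request KV lengths of the batch), since the backend's actual code is not shown on this page:

```python
import torch

def split_seq_len_params(seq_lens: torch.Tensor, max_context_len: int):
    # Per-batch value: the longest KV sequence actually present in this batch.
    # After this fix, metadata.max_seq_len_k is meant to hold this number.
    max_seq_len_k = int(seq_lens.max().item())

    # Static capacity: the model's maximum context length. The _attn_fwd call
    # sites pass this value instead, so the kernel always sees the size the
    # KV cache structures were allocated for, independent of the batch.
    kernel_max_kv_len = max_context_len

    return max_seq_len_k, kernel_max_kv_len
```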
Code Review
This pull request correctly fixes an issue with the max_seq_len_k parameter in the trtllm_mha attention backend. The changes properly distinguish between the maximum sequence length within a batch and the model's overall maximum context length. The modifications ensure that metadata.max_seq_len_k accurately reflects the current batch's maximum sequence length, which is semantically correct. Most importantly, the calls to the underlying flashinfer kernels in forward_decode and forward_extend are updated to use self.max_context_len. This is the correct approach, as these kernels likely require the maximum potential length for which the KV cache structures are sized, rather than the dynamic maximum length of the current batch. The fix is logical, well-targeted, and resolves the inconsistency in the original implementation.
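To make that distinction concrete, here is a rough sketch of a decode call site under those assumptions; the real _attn_fwd signature is not visible on this page, so the callable and keyword arguments below are placeholders rather than the backend's actual API:

```python
from dataclasses import dataclass

@dataclass
class _Metadata:
    max_seq_len_k: int  # longest KV sequence in the current batch (dynamic)

def decode_call_sketch(attn_fwd, q, kv_cache, metadata: _Metadata, max_context_len: int):
    # The kernel receives the static max_context_len (the capacity the KV
    # cache and page table were sized for), not the per-batch
    # metadata.max_seq_len_k, so a CUDA graph that captures this call replays
    # with consistent size arguments even when later batches are longer.
    return attn_fwd(
        q,
        kv_cache,
        max_seq_len=max_context_len,
        max_kv_len=max_context_len,
    )
```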
yyihuang left a comment
LGTM. The max_seq_len should be >= the page table stride for the trtllm-gen family of attention kernels.
We could follow up in flashinfer-ai/flashinfer#1407 to remove this param after the new cubin publishing and launcher-params refactor (flashinfer-ai/flashinfer#1518). cc @yzh119 to confirm this TODO item.
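As a rough illustration of that constraint, assuming a page table allocated with ceil(max_context_len / page_size) entries per sequence (the backend's actual layout may differ):

```python
import math

def satisfies_page_table_constraint(max_seq_len_arg: int,
                                    max_context_len: int,
                                    page_size: int) -> bool:
    # Per-sequence width (stride) of a page table sized for the full context
    # window rather than for the current batch.
    page_table_stride = math.ceil(max_context_len / page_size)

    # Constraint quoted above for the trtllm-gen kernels: the max_seq_len
    # argument must be at least the page table stride. The per-batch maximum
    # can fall below it for long context windows, while max_context_len
    # always satisfies it when page_size >= 1.
    return max_seq_len_arg >= page_table_stride
```

For example, with max_context_len = 131072 and page_size = 16 the stride is 8192, so a per-batch maximum of 512 would fail the check while max_context_len itself passes.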
Motivation
Modifications
Accuracy Tests
Accuracy: 0.7089646464646465
Benchmarking and Profiling
Checklist