[Kernel] Flashinfer MLA (trtllm-gen) decode kernel integration #21078
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request introduces a new FlashInfer MLA (Multi-head Latent Attention) decode kernel for the vLLM V1 engine. There are critical inconsistencies between the backend implementation and its corresponding test file regarding the shape of the kv_cache tensor and the function signature of the FlashInfer kernel being called. These issues need to be addressed.
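As an illustration of the kind of mismatch flagged above, a shape-consistency check along these lines could be added to the test. This is a minimal sketch: the constants, names, and layout below are assumptions for illustration, not the PR's actual code.

```python
# Hypothetical sketch: the names, dims, and layout below are assumptions for
# illustration, not the actual vLLM/FlashInfer interfaces.
import torch

# For MLA, the compressed KV cache is typically a single latent tensor per
# block: [num_blocks, block_size, kv_lora_rank + qk_rope_head_dim].
NUM_BLOCKS, BLOCK_SIZE = 128, 64
KV_LORA_RANK, QK_ROPE_HEAD_DIM = 512, 64

def expected_kv_cache_shape(num_blocks: int, block_size: int) -> tuple:
    """Shape the (hypothetical) decode-kernel wrapper expects."""
    return (num_blocks, block_size, KV_LORA_RANK + QK_ROPE_HEAD_DIM)

# The tensor the test actually constructs.
kv_cache = torch.zeros(
    NUM_BLOCKS, BLOCK_SIZE, KV_LORA_RANK + QK_ROPE_HEAD_DIM, dtype=torch.bfloat16
)

# A mismatch here is exactly the kind of inconsistency the review flagged.
assert kv_cache.shape == expected_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE), (
    f"test kv_cache shape {tuple(kv_cache.shape)} != expected layout "
    f"{expected_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE)}"
)
```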
Force-pushed from b7177bb to afe27ce
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: hjjq <hanjieq@nvidia.com>
@LucasWilkinson this is ready for review. Thanks.
Moving to draft while investigating performance.
This pull request has merge conflicts that must be resolved before it can be merged.
@farazkh80 for review.
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: hjjq <hanjieq@nvidia.com>
@LucasWilkinson could you please review?
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: hjjq <hanjieq@nvidia.com>
Thanks @MatthewBonanni and @LucasWilkinson, I've made the changes and verified correctness with gsm8k. I've also reverted @mgoin's changes so that …
CI run on main that confirms these test failures are not caused by this PR:
@hjjq @LucasWilkinson the remaining CI failures all fail on main as well.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: hjjq <hanjieq@nvidia.com>
LucasWilkinson left a comment:
LGTM! Thanks!
mgoin left a comment:
Thanks!
…project#21078) Signed-off-by: hjjq <hanjieq@nvidia.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
…project#21078) Signed-off-by: hjjq <hanjieq@nvidia.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
E2E Benchmark:
TRTLLM-gen (via Flashinfer) MLA:
cutlass MLA (VLLM_ATTENTION_BACKEND=CUTLASS_MLA):
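For an A/B end-to-end comparison like the one above, the attention backend can be selected via the VLLM_ATTENTION_BACKEND environment variable before the engine is constructed. A minimal sketch follows; the model choice and the FLASHINFER_MLA backend name are assumptions here, while CUTLASS_MLA comes from the line above.

```python
# Sketch of an A/B setup for the E2E comparison; the model choice and the
# FLASHINFER_MLA name are assumptions, CUTLASS_MLA comes from the PR text.
import os

# Select the attention backend before the engine is constructed.
# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"  # trtllm-gen path (assumed name)
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA"        # baseline used above

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)  # assumed MLA model
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```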
Kernel-only microbenchmark:
Under high concurrency, trtllm-gen is 25% faster than CUTLASS at the kernel level. However, end-to-end latency is not bottlenecked by MLA, so the E2E speedup is less pronounced.
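Kernel-only numbers like the above would typically come from a CUDA-event timing harness around the decode kernel call. Below is a generic sketch, not the actual microbenchmark used here; `run_decode` is a hypothetical stand-in for whichever MLA decode kernel is being measured, and the matmul at the end is only a placeholder workload.

```python
# Generic CUDA-event timing harness (a sketch, not the actual microbenchmark).
# `run_decode` is a hypothetical stand-in for the MLA decode kernel call.
import torch

def time_kernel(run_decode, warmup: int = 10, iters: int = 100) -> float:
    """Return the average latency per call in milliseconds."""
    for _ in range(warmup):
        run_decode()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        run_decode()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Placeholder workload only; substitute the actual kernel call being compared.
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
print(f"{time_kernel(lambda: x @ x):.3f} ms per call")
```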
Accuracy tests:
commands (see the reproduction sketch after this list):
gsm8k:
gpqa:
mmlu:
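The exact evaluation commands were not captured above. A hedged sketch of how such accuracy checks are commonly reproduced with the lm-evaluation-harness Python API is shown below; the model name, model_args, and task selection are assumptions for illustration.

```python
# Sketch only: the PR's exact commands were not captured here. The model name,
# model_args, and task selection are assumptions for illustration.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=deepseek-ai/DeepSeek-V2-Lite,"
        "trust_remote_code=True,tensor_parallel_size=1"
    ),
    tasks=["gsm8k"],  # run similarly with "gpqa" and "mmlu"
    batch_size="auto",
)
print(results["results"]["gsm8k"])
```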