[Kernel] Added flashinfer fp8 per-tensor gemms #22895
Conversation
Code Review
This pull request introduces support for FlashInfer's FP8 GEMM kernels, which is expected to improve performance, particularly for large batch sizes. The changes primarily involve refactoring the GEMM dispatch logic to accommodate a new 'flashinfer' backend and adding an optimization to pre-calculate combined scales. While the implementation is largely sound, I've identified a critical issue in the new FlashInfer wrapper where the output tensor is not reshaped, potentially causing shape mismatches for inputs with more than two dimensions.
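For illustration only, here is a minimal sketch of the reshape concern described above, assuming a wrapper around a 2D-only GEMM kernel; the function and argument names are hypothetical, not the PR's actual symbols.

```python
import torch


def fp8_gemm_nd(a: torch.Tensor, b: torch.Tensor, gemm_2d) -> torch.Tensor:
    """Run a 2D-only GEMM kernel on an input with arbitrary leading dims."""
    # Flatten all leading dimensions into a single row dimension.
    a_2d = a.reshape(-1, a.shape[-1])
    out_2d = gemm_2d(a_2d, b)
    # Without this final reshape, a caller passing 3D activations
    # (e.g. [batch, seq, hidden]) would get a 2D tensor back.
    return out_2d.reshape(*a.shape[:-1], out_2d.shape[-1])
```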
Signed-off-by: Julien Lin <jullin@nvidia.com>
Force-pushed from 4738c59 to 26ddfd9
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
@nvjullin The Blackwell test failures look clearly related: https://buildkite.com/vllm/ci/builds/27727/steps/canvas?jid=0198c767-0112-49f4-9f26-c9fef601374c#0198c767-0112-49f4-9f26-c9fef601374c/98-3479
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
@mgoin The remaining errors are all Hugging Face gateway timeouts.
ProExpertProg
left a comment
A few minor notes.
This might be a bit too urgent, but in general we should really improve the fp8 scaled_mm dispatching. I started a draft PR (#19434) but never got around to it.
@nvjullin Please address the comments and rebase this PR. Thanks!
Signed-off-by: Julien Lin <jullin@nvidia.com>
Force-pushed from 09089d8 to 9a847f7
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Julien Lin <jullin@nvidia.com>
No, I think it's out of scope for this PR. But if you look at the dispatching for int8 or Marlin/Machete, that's closer to what we want to do in general when dispatching between multiple possible implementations.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Julien Lin <jullin@nvidia.com>
mgoin
left a comment
LGTM to get in, thanks. We should follow up by using an Enum instead of raw strings.
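As a hedged sketch of what that follow-up could look like (the names below are illustrative, not taken from the PR or vLLM's actual code):

```python
import enum


class Fp8GemmBackend(enum.Enum):
    """Hypothetical enum replacing raw backend strings in the dispatcher."""
    CUTLASS = "cutlass"
    TORCH = "torch"
    FLASHINFER = "flashinfer"


def choose_backend(use_flashinfer: bool, cutlass_supported: bool) -> Fp8GemmBackend:
    # Prefer FlashInfer when enabled, otherwise fall back to CUTLASS or torch.
    if use_flashinfer:
        return Fp8GemmBackend.FLASHINFER
    return Fp8GemmBackend.CUTLASS if cutlass_supported else Fp8GemmBackend.TORCH
```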
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Purpose
Added FP8 GEMMs from FlashInfer.
The added GEMMs have better or equal performance compared to the original GEMMs, so they are used as the default.
For GEMM sizes with small M, the added GEMMs are marginally faster.
For GEMM sizes with large M, the added GEMMs are much faster.
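For context, a minimal sketch of the per-tensor FP8 GEMM with a pre-combined scale (the "pre-calculated combined scales" optimization mentioned in the review above). The function names are hypothetical and the body is a plain reference computation, not the FlashInfer kernel itself.

```python
import torch


def process_weights(weight_scale: torch.Tensor,
                    input_scale: torch.Tensor) -> torch.Tensor:
    # For per-tensor scales, (A_q * sa) @ (B_q * sb) == (A_q @ B_q) * (sa * sb),
    # so the product sa * sb can be computed once at weight-loading time.
    return weight_scale * input_scale


def fp8_per_tensor_gemm(a_q: torch.Tensor, b_q: torch.Tensor,
                        combined_scale: torch.Tensor,
                        out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Stand-in for the fused kernel: accumulate in fp32, then apply the
    # single pre-combined scale and cast to the output dtype.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32)
    return (acc * combined_scale).to(out_dtype)
```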
These are the results for llama3 with ISL=OSL=1024, concurrency=128, max_num_batched_tokens=8192, TP1.
As expected, TPOT is roughly the same but TTFT improved by ~13%.
Requires FlashInfer autotuning, so depends on #22346 (merged). Functionality depends on FlashInfer PR flashinfer-ai/flashinfer#1479 (merged). Perf numbers depend on FlashInfer PR flashinfer-ai/flashinfer#1491 (merged). Requires the next FlashInfer release including the aforementioned PRs and a vLLM update to that FlashInfer version (FlashInfer has since been updated).
old: [benchmark results]
new: [benchmark results]
lm_eval shows
Test Plan
Tests to be added.
Test Result
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.