[ROCm] Add aiter tkw1 kernel for Llama4 fp8 #16727
vllm-bot merged 14 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; instead, only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Force-pushed from 88e60fb to 6659b99.
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
vllm/envs.py (outdated)
```python
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE: bool = False
```
Can we make the env name align more closely with the kernel name, in this case by including tkw1 in the name?
```python
def is_rocm_aiter_channel_scaled_moe_enabled() -> bool:
    return is_rocm_aiter_moe_enabled() and \
        envs.VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE
```
Does this tkw1 enablement need to depend on is_rocm_aiter_moe_enabled()?
In this enablement we follow the block_scaled_moe case, using VLLM_ROCM_USE_AITER_MOE as the master switch for enabling MoE ops, to stay consistent with the other aiter kernels.
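For illustration, a minimal sketch of the master switch itself, assuming the env names shown in this thread (vLLM's actual definition may differ):

```python
import vllm.envs as envs

def is_rocm_aiter_moe_enabled() -> bool:
    # VLLM_ROCM_USE_AITER is the global aiter switch;
    # VLLM_ROCM_USE_AITER_MOE gates all aiter MoE ops beneath it.
    # Kernel-specific flags only take effect when both are set.
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE
```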
```python
if activation_str == "silu":
    activation = ActivationType.Silu
elif activation_str == "gelu":
    activation = ActivationType.Gelu
else:
    activation = ActivationType.Silu
```
Can this be simplified to a one-liner?
Suggested change:
```python
activation = ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
```
Do we need an additional wrapper for the _tkw1 kernel, given that it's just a kernel call plus an activation type conversion? The activation type conversion could also be used by other branches / kernel calls.
We are wrapping the kernel call because a future PR enabling torch.compile for aiter MoE kernels will use these wrappers to register the aiter ops, so we are leaving the wrapper in place for now.
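For context, a hedged sketch of the kind of registration such a wrapper enables, using PyTorch's torch.library.custom_op API; the op name, signature, and body here are illustrative placeholders, not vLLM's actual registration code:

```python
import torch

@torch.library.custom_op("vllm_rocm::asm_moe_tkw1", mutates_args=())
def asm_moe_tkw1(hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor,
                 topk_ids: torch.Tensor) -> torch.Tensor:
    # The real wrapper would dispatch to aiter's tkw1 kernel here.
    return torch.empty_like(hidden_states)

@asm_moe_tkw1.register_fake
def _(hidden_states, topk_weights, topk_ids):
    # Shape/dtype inference so torch.compile can trace through the op.
    return torch.empty_like(hidden_states)
```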
```python
# # All AITER Fused MoE kernels are expecting the following datatypes
# topk_weights = topk_weights.to(torch.float32)
# topk_ids = topk_ids.to(torch.int32)
```
Suggested change: remove these commented-out lines.
```python
# topk_weights = topk_weights.to(torch.float32)
# topk_ids = topk_ids.to(torch.int32)

return rocm_aiter_asm_moe_tkw1(hidden_states,
```
Let's assert apply_router_weight_on_input=True, or do the if-branch check, when calling the _tkw1 kernel. Also, we should add comments illustrating the difference between the _tkw1 kernel and the other aiter kernels: the difference is whether topk_weights is applied to the output of the first GEMM or the second GEMM.
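For illustration, a minimal sketch of the requested guard, assuming the rocm_aiter_asm_moe_tkw1 wrapper from the diff above is in scope (names follow this thread, not necessarily the merged code):

```python
def _call_tkw1(hidden_states, apply_router_weight_on_input, **kwargs):
    # tkw1 folds topk_weights into the output of the first GEMM, unlike the
    # other aiter fused-MoE kernels, which apply them after the second GEMM,
    # so it is only valid when router weights are applied on the input.
    if not apply_router_weight_on_input:
        raise ValueError("rocm_aiter_asm_moe_tkw1 only supports "
                         "apply_router_weight_on_input=True")
    return rocm_aiter_asm_moe_tkw1(hidden_states, **kwargs)
```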
```python
        and layer.activation == "silu" and layer.expert_map is None):
    return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
    if is_rocm_aiter_channel_scaled_moe_enabled():
```
tkw1 is not general support for FP8 fused-MoE channel / rowwise scaling; it only supports the case where apply_router_weight_on_input=True.
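A sketch of the narrower dispatch this comment implies, assuming the enablement helper from the diff above and an apply_router_weight_on_input attribute (both illustrative here):

```python
# Route to the tkw1 path only for the case it supports; otherwise fall
# back to the default FP8 W8A8 MoE method.
if (is_rocm_aiter_channel_scaled_moe_enabled()
        and layer.apply_router_weight_on_input):
    moe_method = "aiter_tkw1"       # illustrative marker, not vLLM's name
else:
    moe_method = "default_fp8_w8a8"
```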
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
…E_AITER_FP8_BLOCK_SCALED_MOE and VLLM_ROCM_USE_AITER_FP8_TKW1_MOE Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
```python
)

if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
    # TODO: verify this code path for DeepSeekV3
```
Can we verify before landing?
Verified; will remove the TODO comment.
2025-04-18:10:35:16 INFO [loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=30000,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match ↑ | 0.9492 | ± 0.006 |
| gsm8k | 3 | strict-match | 5 | exact_match ↑ | 0.9500 | ± 0.006 |
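For reference, a hedged sketch of reproducing this check through lm-evaluation-harness's Python API rather than the CLI; the model_args string mirrors the configuration line above:

```python
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=("pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,"
                "max_model_len=30000,gpu_memory_utilization=0.8,"
                "trust_remote_code=True"),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])  # exact_match for both filters
```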
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
SageMoore left a comment:
Looks reasonable. Just a few nits.
```python
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                     requires_grad=False)

if self.use_rocm_aiter_moe:
```
Nit: Can you merge these into one if statement?
Will do. Thanks for pointing this out.
```python
    is_rocm_aiter_moe_enabled)

# Property to determine if AITER is used
self.use_rocm_aiter_moe = is_rocm_aiter_moe_enabled()
```
Nit: Do you need to store this in the class? It doesn't look like you are using it outside of this function.
You're right. Updated this along with the merged if statement.
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
This PR enables aiter's tkw1 quantized MoE kernel to improve the inference performance of compressed-tensors Llama4 models quantized with FP8. We have also revamped aiter's MoE kernel dispatching to automatically choose a suitable AITER fused MoE kernel without needing to set flags for kernel selection. Users only need to specify `VLLM_ROCM_USE_AITER=1` and `VLLM_ROCM_USE_AITER_MOE=1` to activate aiter's MoE kernels, and the `VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE` flag is removed.

Note: torch.compile isn't supported in this PR yet, and the performance numbers were obtained with V1 eager mode. Enabling V1 torch.compile for aiter MoE kernels will be addressed in a separate PR.
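For illustration, the resulting user-facing setup, sketched in Python (setting the variables in the shell before launching vLLM works the same way; the model name matches the benchmarks below):

```python
import os

# The only two switches needed after this PR; kernel selection below
# them is automatic.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

from vllm import LLM  # import after the env vars are set

llm = LLM(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
          tensor_parallel_size=4, enforce_eager=True)
```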
Llama4 Maverick FP8 throughput benchmarks
Llama4 Maverick FP8 latency benchmarks
Text Generation Response
lm_eval Results
V1 without aiter, eager mode
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=30000,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
V1 with aiter, eager mode
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=30000,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
Reduce complexity of selecting AITER Fused MoE kernel
As the number of AITER flags has increased, we have revamped the conditions for picking the AITER fused MoE kernel so that no kernel-specific flags are needed, and `VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE` is removed. Users only need to specify `VLLM_ROCM_USE_AITER=1` and `VLLM_ROCM_USE_AITER_MOE=1` (a sketch of the resulting dispatch appears at the end of this description). We have validated the code paths of other models with the latest AITER fused MoE selection logic:
mistralai_Mixtral-8x7B-Instruct-v0.1_V0
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=1,max_model_len=30000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
mistralai_Mixtral-8x7B-Instruct-v0.1_FP8_V0
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=1,max_model_len=30000,quantization=fp8,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
deepseek-ai_DeepSeek-V3
vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=30000,gpu_memory_utilization=0.8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
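For illustration, a hedged sketch of the automatic kernel selection described above; the branch conditions are inferred from this PR's discussion and simplified, so the merged dispatch code may differ:

```python
def select_aiter_fused_moe_kernel(use_fp8_w8a8: bool,
                                  block_shape,
                                  apply_router_weight_on_input: bool) -> str:
    if use_fp8_w8a8 and block_shape is not None:
        # Block-scaled FP8 path, e.g. DeepSeek-V3.
        return "fp8_block_scaled_moe"
    if use_fp8_w8a8 and apply_router_weight_on_input:
        # Channel/rowwise-scaled FP8 with router weights applied on the
        # input, e.g. Llama4 FP8: the tkw1 kernel added in this PR.
        return "asm_moe_tkw1"
    # Default aiter fused-MoE path, e.g. Mixtral.
    return "fused_moe"
```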