
Conversation

@vllmellm (Contributor) commented Aug 22, 2025

Purpose

Currently in V1, RMSNorm is implemented as a CustomOp class.

By default in V1, when torch.compile is enabled, custom_ops is set to "none", which forces the RMSNorm class to call forward_native.

Since the AITER RMS norm kernel is faster, we want to use forward_cuda instead, so we append "rms_norm" to self.compilation_config.custom_ops when the user enables AITER RMS norm.

This PR fixes the AITER RMS norm V1 custom op registration and adds dtype checking, since AITER RMS norm does not support float32 inputs.
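For illustration, a minimal sketch (not the exact vLLM code) of the op-override semantics: "none" disables all custom ops, while an explicit "+rms_norm" entry opts a single op back in, so nothing needs to be removed from the list. The simplified `CompilationConfig` and the helper name below are assumptions.

```python
# Minimal sketch of the "+rms_norm" override, assuming a simplified
# CompilationConfig; the helper name is hypothetical.
from dataclasses import dataclass, field


@dataclass
class CompilationConfig:
    # "none" disables all CustomOps; "+<op>" re-enables a single op.
    custom_ops: list[str] = field(default_factory=lambda: ["none"])


def enable_aiter_rms_norm(config: CompilationConfig, aiter_enabled: bool) -> None:
    # Opt the RMSNorm CustomOp back in so forward_cuda (the AITER path) is used
    # instead of forward_native when torch.compile is enabled.
    if aiter_enabled and "+rms_norm" not in config.custom_ops:
        config.custom_ops.append("+rms_norm")


cfg = CompilationConfig()
enable_aiter_rms_norm(cfg, aiter_enabled=True)
print(cfg.custom_ops)  # ['none', '+rms_norm']
```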

Test Plan

Run lm-eval (GSM8K) on selected models and compare the AITER and non-AITER kernels.

Test Result

| Model | Task | Version | Filter | n-shot | Metric | AITER On Value | AITER On Stderr | AITER Off Value | AITER Off Stderr |
|---|---|---|---|---|---|---|---|---|---|
| deepseek-ai/DeepSeek-V3 | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9530 | 0.0058 | 0.9492 | 0.0060 |
| deepseek-ai/DeepSeek-V3 | gsm8k | 3 | strict-match | 5 | exact_match | 0.9515 | 0.0059 | 0.9492 | 0.0060 |
| llama/Llama-4-Scout-17B-16E | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8211 | 0.0106 | 0.8271 | 0.0104 |
| llama/Llama-4-Scout-17B-16E | gsm8k | 3 | strict-match | 5 | exact_match | 0.8203 | 0.0106 | 0.8256 | 0.0105 |
| EmbeddedLLM/Qwen2.5-1.5B-Instruct-FP8-Dynamic | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.5186 | 0.0138 | 0.5186 | 0.0138 |
| EmbeddedLLM/Qwen2.5-1.5B-Instruct-FP8-Dynamic | gsm8k | 3 | strict-match | 5 | exact_match | 0.3397 | 0.0130 | 0.3404 | 0.0131 |
| Qwen/Qwen3-32B | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6224 | 0.0134 | 0.6240 | 0.0133 |
| Qwen/Qwen3-32B | gsm8k | 3 | strict-match | 5 | exact_match | 0.7354 | 0.0122 | 0.7362 | 0.0121 |

(Optional) Documentation Update

None

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm changed the title from "register aiter norm funcs as custom ops" to "[ROCm][Bugfix] Fix Aiter RMSNorm" on Aug 22, 2025
mergify bot added the "rocm" (Related to AMD ROCm) label on Aug 22, 2025
@ProExpertProg (Collaborator) left a comment:

Move some stuff around please otherwise looks good!

Comment on lines 3622 to 3623
if "none" in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.remove("none")
Collaborator:

We don't need to remove "none" from the list, "+rms_norm" overrides it. Also, could you move this to a platform-specific config update (current_platform.check_and_update_config)?
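
As a rough illustration of where such an update could live, a hedged sketch of a platform hook shaped like check_and_update_config; the class and flag below are simplified stand-ins, not the actual vLLM ROCm platform code.

```python
# Hedged sketch: a platform-specific config hook shaped like
# check_and_update_config. RocmPlatformStub and use_aiter_rms_norm are
# illustrative stand-ins, not vLLM's actual classes or flags.
class RocmPlatformStub:
    use_aiter_rms_norm: bool = True

    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        custom_ops = vllm_config.compilation_config.custom_ops
        # "+rms_norm" takes precedence over "none", so no removal is needed.
        if cls.use_aiter_rms_norm and "+rms_norm" not in custom_ops:
            custom_ops.append("+rms_norm")
```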

Collaborator:

Also, this enables the custom ROCm RmsNorm, are we sure we want to do that (even if AITER is not enabled)?

@vllmellm (Contributor, author):

@ProExpertProg I moved it and added a check to append "+rms_norm" only when AITER is enabled.


add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
norm_func = dispatch_rmsnorm_func(add_residual, self.dtype)
@ProExpertProg (Collaborator) commented Aug 27, 2025:

Instead of dispatching inside forward_cuda, please add forward_rocm. Also the dispatching can happen inside __init__ instead of during runtime, that's the benefit of the CustomOp abstraction.
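
A hedged sketch of what "dispatch inside __init__" means here: select the kernel once at construction time and store it, instead of branching on every forward call. The dispatcher name, the AITER availability flag, and the placeholder kernel are simplified assumptions, not vLLM's actual code.

```python
import torch


def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference RMSNorm used as the fallback path.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight


def dispatch_rmsnorm_func(dtype: torch.dtype, aiter_enabled: bool):
    # Hypothetical dispatcher: AITER kernels do not support float32, so fall
    # back to the native path for fp32 inputs or when AITER is disabled.
    if aiter_enabled and dtype in (torch.float16, torch.bfloat16):
        return rms_norm_native  # stand-in for the AITER kernel call
    return rms_norm_native


class RMSNormSketch(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6,
                 aiter_enabled: bool = False) -> None:
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        # Dispatch once here instead of on every forward call.
        self.norm_func = dispatch_rmsnorm_func(self.weight.dtype, aiter_enabled)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm_func(x, self.weight, self.eps)
```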

@vllmellm (Contributor, author) commented Aug 29, 2025:

@ProExpertProg I have moved the dispatch to the __init__. However, since ROCm and CUDA share the same logical flow, it would be redundant to implement forward_hip with the same code. The CustomOp class assumes that HIP ops are compatible with CUDA ops: forward_hip defaults to the CUDA implementation, and this is documented in the class. Perhaps we could implement a forward "helper" function and explicitly call it in both forward_cuda and forward_hip, if that would make it clearer for other developers?

Collaborator:

> it would be redundant to implement the forward_hip with the same code

I think here separation of concerns is more important than DRY (don't repeat yourself).

> forward_hip defaults to the CUDA implementation and it is documented in the class

Yes, but here we're changing the default behavior, so it's okay to override the forward_hip method.

@ProExpertProg (Collaborator):

Also, @vllmellm could we do a follow-up where we enable aiter_rmsnorm + quant -> aiter_rmsnorm_quant fusion? I heard there were aiter kernels for this

@vllmellm (Contributor, author) commented Aug 29, 2025:

> Also, @vllmellm could we do a follow-up where we enable aiter_rmsnorm + quant -> aiter_rmsnorm_quant fusion? I heard there were aiter kernels for this

Sure thing! We'll just have to profile it first.

else:
    parallel_config.worker_cls = "vllm.worker.worker.Worker"

if use_v1 and use_aiter_rms_norm and not enforce_eager:
Collaborator:

Why does enforce_eager matter here?

@vllmellm (Contributor, author):

This is solely for performance reasons. AITER RMS norm functions perform best when CUDA graph capture is enabled.

Collaborator:

In that case please add a comment and use the cudagraph_mode property: cudagraph_mode != CUDAGraphMode.NONE
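
A small hedged sketch of the suggested gating; the enum values and signature below are simplified stand-ins based on the reviewer's comment, not vLLM's exact definitions.

```python
# Hedged sketch: gate on the CUDA graph mode rather than enforce_eager.
import enum


class CUDAGraphMode(enum.Enum):
    # Simplified stand-in for the real enum.
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def should_enable_aiter_rms_norm(use_v1: bool, use_aiter_rms_norm: bool,
                                 cudagraph_mode: CUDAGraphMode) -> bool:
    # AITER RMS norm performs best with CUDA graph capture, so only opt in
    # when graphs are not disabled.
    return use_v1 and use_aiter_rms_norm and cudagraph_mode != CUDAGraphMode.NONE
```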

return self.forward_native(x, residual)

add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
Collaborator:

Can you still split up the CUDA and hip paths here? Even though there will be slight repetition, I think it will be more clear what's happening in each platform.

Collaborator:

(To be clear, only HIP should use the dispatching; CUDA should call the vLLM custom ops directly.)
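
A hedged sketch of what that split could look like: forward_cuda calls the existing kernel path directly, while forward_hip owns the AITER/dtype dispatching. The method names follow the CustomOp convention discussed above; the kernel calls are placeholders, not the real vLLM or AITER ops.

```python
import torch


class RMSNormSplitSketch(torch.nn.Module):
    # Illustrative only: the CUDA-op and AITER calls are placeholders.
    def __init__(self, hidden_size: int, eps: float = 1e-6,
                 aiter_enabled: bool = False) -> None:
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.aiter_enabled = aiter_enabled

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # CUDA path: call the vLLM custom op directly (placeholder here).
        return self.forward_native(x)

    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
        # HIP path: dispatch to AITER only for supported dtypes (no float32).
        if self.aiter_enabled and x.dtype in (torch.float16, torch.bfloat16):
            return self.forward_native(x)  # placeholder for the AITER kernel
        return self.forward_cuda(x)
```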

Comment on lines 150 to 155
weight_dtype = self.weight.data.dtype

self.norm_func = dispatch_rmsnorm_func(with_fused_add=False,
                                        dtype=weight_dtype)
self.norm_func_with_add = dispatch_rmsnorm_func(with_fused_add=True,
                                                dtype=weight_dtype)
Collaborator:

Can you guard this under is_rocm?

@ProExpertProg (Collaborator) left a comment:

A quick nit, and please avoid using enforce_eager directly.

weight_dtype = self.weight.data.dtype

if current_platform.is_rocm():
    self.rocm_aiter_norm_func = dispatch_rocm_rmsnorm_func(
Collaborator:

Nit: self.rocm_norm_fun


@ProExpertProg enabled auto-merge (squash) on September 9, 2025
github-actions bot added the "ready" (ONLY add when PR is ready to merge/full CI is needed) label on Sep 9, 2025
auto-merge was automatically disabled on September 10, 2025 (head branch was pushed to by a user without write access)

@DarkLight1337 merged commit 7c195d4 into vllm-project:main on Sep 10, 2025
39 checks passed
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025