[ROCm][Bugfix] Fix Aiter RMSNorm #23412
Conversation
ProExpertProg left a comment
Move some stuff around please otherwise looks good!
vllm/config/__init__.py (Outdated)
```python
if "none" in self.compilation_config.custom_ops:
    self.compilation_config.custom_ops.remove("none")
```
We don't need to remove "none" from the list; "+rms_norm" overrides it. Also, could you move this to a platform-specific config update (current_platform.check_and_update_config)?
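For illustration, the suggested move might look roughly like the sketch below. This is not the code merged in this PR, and the AITER-enablement helper is a hypothetical placeholder:

```python
from vllm.config import VllmConfig


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    # Hypothetical placeholder: stands in for however AITER RMSNorm is
    # toggled (e.g. an environment flag read by the ROCm platform code).
    return True


def check_and_update_config(vllm_config: VllmConfig) -> None:
    # In vLLM this hook is a classmethod on the platform class
    # (current_platform.check_and_update_config); shown here as a free
    # function to keep the sketch short.
    compilation_config = vllm_config.compilation_config
    # "+rms_norm" takes precedence over a blanket "none" entry, so there
    # is no need to remove "none" from custom_ops first.
    if is_rocm_aiter_rmsnorm_enabled():
        compilation_config.custom_ops.append("+rms_norm")
```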
Also, this enables the custom ROCm RMSNorm; are we sure we want to do that (even if AITER is not enabled)?
@ProExpertProg I moved it and added a check to append "+rms_norm" only when AITER is enabled.
```diff
  add_residual = residual is not None
- norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+ norm_func = dispatch_rmsnorm_func(add_residual, self.dtype)
```
Instead of dispatching inside forward_cuda, please add forward_rocm. Also, the dispatching can happen inside __init__ instead of at runtime; that's the benefit of the CustomOp abstraction.
@ProExpertProg I have moved the dispatch to __init__. However, since ROCm and CUDA share the same logical flow, it would be redundant to implement forward_hip with the same code. The CustomOp class assumes that HIP ops are compatible with CUDA ops: forward_hip defaults to the CUDA implementation, and this is documented in the class. Perhaps we can implement a forward "helper" function and explicitly call it in both forward_cuda and forward_hip if that would make it clearer for other developers?
> it would be redundant to implement forward_hip with the same code
I think here separation of concerns is more important than DRY (don't repeat yourself).
> forward_hip defaults to the CUDA implementation and it is documented in the class
Yes, but here we're changing the default behavior, so it's okay to override the forward_hip method.
Also, @vllmellm could we do a follow-up where we enable aiter_rmsnorm + quant -> aiter_rmsnorm_quant fusion? I heard there were aiter kernels for this.
Sure thing! We'll just have to profile it first.
vllm/platforms/rocm.py (Outdated)
```python
else:
    parallel_config.worker_cls = "vllm.worker.worker.Worker"

if use_v1 and use_aiter_rms_norm and not enforce_eager:
```
Why does enforce_eager matter here?
This is solely for performance reasons: AITER RMSNorm functions perform best when CUDA graph capture is enabled.
In that case please add a comment and use the cudagraph_mode property: cudagraph_mode != CUDAGraphMode.NONE
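Concretely, the quoted condition from vllm/platforms/rocm.py would become something like the following sketch. Variable names follow the diff above; the CUDAGraphMode import path and the appended op name are assumptions based on this thread:

```python
from vllm.config import CUDAGraphMode  # import path assumed

# AITER RMSNorm kernels perform best with CUDA graph capture enabled, so
# only register the custom op when CUDA graphs are not disabled, instead
# of checking enforce_eager directly.
if (use_v1 and use_aiter_rms_norm
        and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
    compilation_config.custom_ops.append("+rms_norm")
```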
```python
    return self.forward_native(x, residual)

add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
```
Can you still split up the CUDA and hip paths here? Even though there will be slight repetition, I think it will be more clear what's happening in each platform.
(to be clear only hip should use the dispatching and CUDA should call the vllm custom ops directly)
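A rough sketch of that split, not the exact code merged in this PR: forward_cuda keeps calling the vLLM custom ops directly, while forward_hip uses the callables selected once in __init__ via dispatch_rocm_rmsnorm_func (see the snippets below). The import paths, attribute names, and the AITER wrapper signatures here are assumptions:

```python
import torch

from vllm import _custom_ops as ops  # stock vLLM custom ops
from vllm.model_executor.custom_op import CustomOp  # import path assumed


class RMSNorm(CustomOp):

    def forward_cuda(self, x: torch.Tensor,
                     residual: torch.Tensor | None = None):
        # CUDA path: call the vLLM custom ops directly, no dispatching.
        if residual is not None:
            ops.fused_add_rms_norm(x, residual, self.weight.data,
                                   self.variance_epsilon)
            return x, residual
        out = torch.empty_like(x)
        ops.rms_norm(out, x, self.weight.data, self.variance_epsilon)
        return out

    def forward_hip(self, x: torch.Tensor,
                    residual: torch.Tensor | None = None):
        # ROCm path: use whichever functions __init__ picked (AITER kernels
        # for supported dtypes, the default ops otherwise). The attribute
        # names are illustrative and assumed to be set in __init__.
        if residual is not None:
            return self.rocm_norm_func_with_add(x, residual,
                                                self.weight.data,
                                                self.variance_epsilon)
        return self.rocm_norm_func(x, self.weight.data,
                                   self.variance_epsilon)
```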
```python
weight_dtype = self.weight.data.dtype

self.norm_func = dispatch_rmsnorm_func(with_fused_add=False,
                                       dtype=weight_dtype)
self.norm_func_with_add = dispatch_rmsnorm_func(with_fused_add=True,
                                                dtype=weight_dtype)
```
Can you guard this under is_rocm?
ProExpertProg left a comment
A quick nit, and please avoid using enforce_eager directly.
```python
weight_dtype = self.weight.data.dtype

if current_platform.is_rocm():
    self.rocm_aiter_norm_func = dispatch_rocm_rmsnorm_func(
```
Nit: self.rocm_norm_fun
@ProExpertProg done!
Purpose
Currently in V1, RMSNorm is implemented as a CustomOp class. By default, when torch.compile is enabled in V1, custom_ops is set to "none", which forces the RMSNorm class to call forward_native. Since the AITER RMS norm kernels are faster, we want to use forward_cuda instead, so we append "+rms_norm" to self.compilation_config.custom_ops when the user enables AITER RMSNorm. This PR fixes the AITER RMSNorm V1 custom op registration and adds a dtype check, since AITER RMSNorm does not support float32 inputs.
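For reference, the dtype check boils down to a dispatch of roughly the following shape. This is a simplified sketch: the placeholder callables stand in for the AITER kernels and the stock vLLM ops rather than the exact functions wired up in this PR, and the fp16/bf16 support set is an assumption.

```python
from typing import Callable

import torch


def _aiter_rms_norm(x, weight, eps):  # placeholder for the AITER kernel
    raise NotImplementedError


def _aiter_fused_add_rms_norm(x, residual, weight, eps):  # placeholder
    raise NotImplementedError


def _default_rms_norm(x, weight, eps):  # placeholder for the stock op
    raise NotImplementedError


def _default_fused_add_rms_norm(x, residual, weight, eps):  # placeholder
    raise NotImplementedError


def dispatch_rocm_rmsnorm_func(with_fused_add: bool,
                               dtype: torch.dtype) -> Callable:
    # AITER RMSNorm does not support float32 inputs, so only pick the
    # AITER kernels for half-precision dtypes and fall back to the default
    # custom ops otherwise.
    use_aiter = dtype in (torch.float16, torch.bfloat16)
    if with_fused_add:
        return (_aiter_fused_add_rms_norm if use_aiter
                else _default_fused_add_rms_norm)
    return _aiter_rms_norm if use_aiter else _default_rms_norm
```

With a dispatch like this, __init__ picks the function once based on self.weight.data.dtype, matching the snippets discussed in the review above.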
Test Plan
Run lm_eval on selected models and compare AITER and non-AITER kernels.
Test Result
(Optional) Documentation Update
None
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.