
Commit 033d6a9

[ROCm][Bugfix] Fix Aiter RMSNorm (vllm-project#23412)

Authored by vllmellm, committed by FeiDaLI
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent: 3e82778

File tree

3 files changed: +108, -36 lines


tests/model_executor/test_enabled_custom_ops.py (22 additions, 17 deletions)

@@ -13,13 +13,15 @@
     vllm_topk_softmax)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
-from vllm.model_executor.layers.layernorm import (
-    RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
-    rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.model_executor.layers.layernorm import (RMSNorm,
+                                                  dispatch_rocm_rmsnorm_func,
+                                                  fused_add_rms_norm, rms_norm)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 
+RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+
 
 # Registered subclass for test
 @CustomOp.register("relu3")
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
 
 
 @pytest.mark.parametrize("add_residual", [True, False])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
 @pytest.mark.skipif(not current_platform.is_rocm(),
                     reason="AITER is a feature exclusive for ROCm")
-def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
-                           use_rocm_aiter_norm: str, monkeypatch):
+def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
+                           use_rocm_aiter: str, use_rocm_aiter_norm: str,
+                           monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
     monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
-    rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
-
-    if not add_residual:
-        if current_platform.is_rocm() and int(use_rocm_aiter) and int(
-                use_rocm_aiter_norm):
-            assert rms_norm_func == rocm_aiter_rms_norm
-        else:
-            assert rms_norm_func == rms_norm
-    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
-            use_rocm_aiter_norm):
-        assert rms_norm_func == rocm_aiter_fused_add_rms_norm
-    else:
+    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
+
+    should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
+        and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
+
+    if add_residual and should_use_rocm_aiter:
+        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+    elif should_use_rocm_aiter:
+        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
+    elif add_residual:
        assert rms_norm_func == fused_add_rms_norm
+    else:
+        assert rms_norm_func == rms_norm
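For reference, the dispatch expectation this updated test encodes can be restated as a small standalone helper. This is a minimal sketch for illustration only (the helper name and its boolean arguments are not part of the diff); it mirrors the rule that the AITER kernels are expected only on ROCm, only when both environment flags are set, and only for float16/bfloat16:

import torch

def expected_rmsnorm_dispatch(is_rocm: bool, use_aiter: bool, use_aiter_norm: bool,
                              dtype: torch.dtype, add_residual: bool) -> str:
    # Mirrors should_use_rocm_aiter in the test above.
    aiter_ok = (is_rocm and use_aiter and use_aiter_norm
                and dtype in (torch.float16, torch.bfloat16))
    if aiter_ok:
        return ("torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add"
                if add_residual else "torch.ops.vllm.rocm_aiter_rms_norm")
    return "fused_add_rms_norm" if add_residual else "rms_norm"

# Example: fp32 never takes the AITER path, even with both flags enabled.
assert expected_rmsnorm_dispatch(True, True, True, torch.float32, False) == "rms_norm"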

vllm/model_executor/layers/layernorm.py (71 additions, 16 deletions)

@@ -9,11 +9,11 @@
 import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+    return envs.VLLM_ROCM_USE_AITER_RMSNORM \
         and envs.VLLM_ROCM_USE_AITER
 
 
@@ -43,8 +43,8 @@ def fused_add_rms_norm(
     return x, residual
 
 
-def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
-                        variance_epsilon: float) -> torch.Tensor:
+def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
     import aiter as rocm_aiter
     if x.dim() > 2:
         x_original_shape = x.shape
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
     return rocm_aiter.rms_norm(x, weight, variance_epsilon)
 
 
-def rocm_aiter_fused_add_rms_norm(
+def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
         variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
 
@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
     return output, residual_out
 
 
-def dispatch_cuda_rmsnorm_func(add_residual: bool):
-    if add_residual:
-        if is_rocm_aiter_rmsnorm_enabled():
-            return rocm_aiter_fused_add_rms_norm
-        return fused_add_rms_norm
+def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(x), torch.empty_like(residual)
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_rms_norm",
+        op_func=rocm_aiter_rms_norm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rms_norm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
+        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
+    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
+        torch.float16, torch.bfloat16
+    ]
+
+    if use_aiter and with_fused_add:
+        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+    if use_aiter:
+        return torch.ops.vllm.rocm_aiter_rms_norm
 
-    if is_rocm_aiter_rmsnorm_enabled():
-        return rocm_aiter_rms_norm
+    # fall back to CUDA implementation
+    if with_fused_add:
+        return fused_add_rms_norm
     return rms_norm
 
 
@@ -114,6 +148,13 @@ def __init__(
         self.weight = torch.ones(hidden_size)
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
+        weight_dtype = self.weight.data.dtype
+
+        if current_platform.is_rocm():
+            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
+                with_fused_add=False, dtype=weight_dtype)
+            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
+                with_fused_add=True, dtype=weight_dtype)
 
     def forward_native(
         self,
@@ -162,13 +203,27 @@ def forward_cuda(
             return self.forward_native(x, residual)
 
         add_residual = residual is not None
-        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+        if add_residual:
+            return fused_add_rms_norm(x, residual, self.weight.data,
+                                      self.variance_epsilon)
+        else:
+            return rms_norm(x, self.weight.data, self.variance_epsilon)
+
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
 
+        add_residual = residual is not None
         if add_residual:
-            return norm_func(x, residual, self.weight.data,
-                             self.variance_epsilon)
+            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
+                                                self.variance_epsilon)
         else:
-            return norm_func(x, self.weight.data, self.variance_epsilon)
+            return self.rocm_norm_func(x, self.weight.data,
+                                       self.variance_epsilon)
 
     def forward_xpu(
         self,
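The key change in layernorm.py is that the AITER kernels are now registered as torch custom ops (with fake implementations so they can be traced during compilation and CUDA graph capture), and the dispatch decision is resolved once per RMSNorm layer from the weight dtype rather than on every forward call, with the ROCm path living in its own forward_hip. Below is a minimal sketch of how the new dispatcher is exercised, assuming a ROCm build of vLLM with AITER installed; the shapes, dtype, and epsilon are illustrative:

import os
# The env flags must be set before vLLM reads them.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"

import torch
from vllm.model_executor.layers.layernorm import dispatch_rocm_rmsnorm_func

# fp16/bf16 resolve to the registered AITER ops; other dtypes fall back to
# the existing fused_add_rms_norm / rms_norm implementations.
norm_fn = dispatch_rocm_rmsnorm_func(with_fused_add=False, dtype=torch.bfloat16)
fused_fn = dispatch_rocm_rmsnorm_func(with_fused_add=True, dtype=torch.bfloat16)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(x)

out = norm_fn(x, weight, 1e-6)                            # rocm_aiter_rms_norm
out2, new_residual = fused_fn(x, residual, weight, 1e-6)  # rocm_aiter_rmsnorm2d_fwd_with_add

Registering through direct_register_custom_op with a fake_impl is what gives the compiler something to trace, which is also why rocm.py (below) turns on the +rms_norm custom op when CUDA graph capture is active.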

vllm/platforms/rocm.py (15 additions, 3 deletions)

@@ -322,23 +322,35 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm.config.compilation import CUDAGraphMode
+
         cache_config = vllm_config.cache_config
+        compilation_config = vllm_config.compilation_config
+        parallel_config = vllm_config.parallel_config
+        is_eager_execution = compilation_config == CUDAGraphMode.NONE
+
+        use_v1 = envs.VLLM_USE_V1
+        use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
+            envs.VLLM_ROCM_USE_AITER_RMSNORM
+
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
-        parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             if vllm_config.speculative_config:
-                if not envs.VLLM_USE_V1:
+                if not use_v1:
                     raise NotImplementedError(
                         "Speculative decoding is not supported on vLLM V0.")
                 parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
             else:
-                if envs.VLLM_USE_V1:
+                if use_v1:
                     parallel_config.worker_cls = \
                         "vllm.v1.worker.gpu_worker.Worker"
                 else:
                     parallel_config.worker_cls = "vllm.worker.worker.Worker"
+        # Aiter rms norm perform best when CUDA Graph capture is enabled.
+        if use_v1 and use_aiter_rms_norm and not is_eager_execution:
+            compilation_config.custom_ops.append("+rms_norm")
 
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
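The net effect of the platform change is that, on V1 with AITER RMSNorm enabled and CUDA graph capture active, vLLM adds the +rms_norm custom op to the compilation config automatically. A user could make the same opt-in explicit; this is a rough sketch only (the CompilationConfig argument spelling may vary between vLLM versions, and the model name is just an example):

# Run with: VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_RMSNORM=1 python example.py
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    compilation_config=CompilationConfig(custom_ops=["+rms_norm"]),
)
print(llm.generate("Hello"))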
