 import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op


 def is_rocm_aiter_rmsnorm_enabled() -> bool:
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+    return envs.VLLM_ROCM_USE_AITER_RMSNORM \
         and envs.VLLM_ROCM_USE_AITER


@@ -43,8 +43,8 @@ def fused_add_rms_norm(
     return x, residual


-def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
-                        variance_epsilon: float) -> torch.Tensor:
+def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
     import aiter as rocm_aiter
     if x.dim() > 2:
         x_original_shape = x.shape
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
     return rocm_aiter.rms_norm(x, weight, variance_epsilon)


-def rocm_aiter_fused_add_rms_norm(
+def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
         variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
     return output, residual_out


-def dispatch_cuda_rmsnorm_func(add_residual: bool):
-    if add_residual:
-        if is_rocm_aiter_rmsnorm_enabled():
-            return rocm_aiter_fused_add_rms_norm
-        return fused_add_rms_norm
+def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(x), torch.empty_like(residual)
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_rms_norm",
+        op_func=rocm_aiter_rms_norm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rms_norm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
+        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
+    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
+        torch.float16, torch.bfloat16
+    ]
+
+    if use_aiter and with_fused_add:
+        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+    if use_aiter:
+        return torch.ops.vllm.rocm_aiter_rms_norm

-    if is_rocm_aiter_rmsnorm_enabled():
-        return rocm_aiter_rms_norm
+    # fall back to CUDA implementation
+    if with_fused_add:
+        return fused_add_rms_norm
     return rms_norm


@@ -114,6 +148,13 @@ def __init__(
         self.weight = torch.ones(hidden_size)
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
+        weight_dtype = self.weight.data.dtype
+
+        if current_platform.is_rocm():
+            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
+                with_fused_add=False, dtype=weight_dtype)
+            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
+                with_fused_add=True, dtype=weight_dtype)

     def forward_native(
         self,
@@ -162,13 +203,27 @@ def forward_cuda(
             return self.forward_native(x, residual)

         add_residual = residual is not None
-        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+        if add_residual:
+            return fused_add_rms_norm(x, residual, self.weight.data,
+                                      self.variance_epsilon)
+        else:
+            return rms_norm(x, self.weight.data, self.variance_epsilon)
+
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)

+        add_residual = residual is not None
         if add_residual:
-            return norm_func(x, residual, self.weight.data,
-                             self.variance_epsilon)
+            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
+                                                self.variance_epsilon)
         else:
-            return norm_func(x, self.weight.data, self.variance_epsilon)
+            return self.rocm_norm_func(x, self.weight.data,
+                                       self.variance_epsilon)

     def forward_xpu(
         self,
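
For context, the `dispatch_rocm_rmsnorm_func` helper introduced above picks the AITER custom ops for fp16/bf16 weights and falls back to the existing CUDA/HIP kernels otherwise. A minimal sketch of how it might be exercised, assuming a ROCm build of vLLM with AITER installed, `VLLM_ROCM_USE_AITER` / `VLLM_ROCM_USE_AITER_RMSNORM` enabled, and that this change lands in `vllm.model_executor.layers.layernorm` (the import path, shapes, and epsilon below are illustrative assumptions):

```python
import torch

# Assumed module path for the helper added in this diff.
from vllm.model_executor.layers.layernorm import dispatch_rocm_rmsnorm_func

# bf16/fp16 selects the registered AITER custom ops; other dtypes would
# return the plain fused_add_rms_norm / rms_norm fallbacks.
norm_fn = dispatch_rocm_rmsnorm_func(with_fused_add=False, dtype=torch.bfloat16)
fused_fn = dispatch_rocm_rmsnorm_func(with_fused_add=True, dtype=torch.bfloat16)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

out = norm_fn(x, weight, 1e-6)                    # torch.ops.vllm.rocm_aiter_rms_norm
out2, res2 = fused_fn(x, residual, weight, 1e-6)  # torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
```

Registering the eager implementations together with `*_fake` shape functions via `direct_register_custom_op` is what lets these ops be traced by `torch.compile` instead of being called as opaque Python functions.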