
Commit 0c51a59

[feat] Introduce QKNormRoPEAttention class for enhanced attention mechanism in Qwen3 model
- Added QKNormRoPEAttention class to apply QK normalization and RoPE.
- Replaced Qwen3Attention with QKNormRoPEAttention in Qwen3DecoderLayer and Qwen3MoEDecoderLayer.
- Removed unused imports and code related to the previous attention implementation.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent a060e12 commit 0c51a59

File tree

3 files changed: +115 -100 lines changed
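The diffs below show the swap at both call sites. For orientation, a minimal caller-side sketch (illustrative only; it uses just the constructor arguments visible in this commit, and the helper names are hypothetical) of how a decoder layer builds the new attention module and how the fuse_qk_norm_rope flag selects between the fused kernel and the two-step path:

# Sketch only: mirrors the constructor calls in the diffs below (assumed to run
# inside tensorrt_llm/_torch/models/); build_self_attn* are hypothetical helpers.
from ..model_config import ModelConfig
from ..modules.qk_norm_attention import QKNormRoPEAttention


def build_self_attn(model_config: ModelConfig, layer_idx: int) -> QKNormRoPEAttention:
    # Default: fuse_qk_norm_rope=True, so QK RMSNorm and RoPE run in the single
    # torch.ops.trtllm.fused_qk_norm_rope kernel on the packed QKV tensor.
    return QKNormRoPEAttention(model_config, layer_idx=layer_idx)


def build_self_attn_unfused(model_config: ModelConfig, layer_idx: int) -> QKNormRoPEAttention:
    # Opt out of the fused kernel: QK norm runs as two RMSNorms (potentially in
    # parallel on an aux CUDA stream) and RoPE falls back to the base Attention path.
    return QKNormRoPEAttention(model_config,
                               layer_idx=layer_idx,
                               fuse_qk_norm_rope=False)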

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 2 additions & 98 deletions
@@ -4,115 +4,19 @@
 from torch import nn
 from transformers import Qwen3Config
 
-from tensorrt_llm.functional import PositionEmbeddingType
-
 from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig
-from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import TensorParallelMode
-from ..modules.multi_stream_utils import maybe_execute_in_parallel
+from ..modules.qk_norm_attention import QKNormRoPEAttention
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, register_auto_model
 
 
-class Qwen3Attention(Attention):
-
-    def __init__(
-        self,
-        model_config: ModelConfig[Qwen3Config],
-        layer_idx: Optional[int] = None,
-        fuse_qk_norm_rope: bool = True,
-    ):
-        config = model_config.pretrained_config
-
-        if getattr(config, "rope_scaling", None) is not None:
-            pos_embd_params = PositionalEmbeddingParams(
-                type=PositionEmbeddingType.from_string(
-                    config.rope_scaling["type"]),
-                rope=RopeParams.from_config(config),
-            )
-        else:
-            pos_embd_params = PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
-            )
-
-        self.fuse_qk_norm_rope = fuse_qk_norm_rope
-
-        super().__init__(
-            hidden_size=config.hidden_size,
-            num_attention_heads=config.num_attention_heads,
-            num_key_value_heads=config.num_key_value_heads,
-            max_position_embeddings=config.max_position_embeddings,
-            bias=config.attention_bias,
-            pos_embd_params=pos_embd_params,
-            rope_fusion=not self.
-            fuse_qk_norm_rope,  # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
-            layer_idx=layer_idx,
-            dtype=config.torch_dtype,
-            dense_bias=config.attention_bias,
-            config=model_config,
-        )
-
-        self.q_norm = RMSNorm(hidden_size=self.head_dim,
-                              eps=1e-6,
-                              dtype=config.torch_dtype,
-                              has_weights=True)
-        self.k_norm = RMSNorm(hidden_size=self.head_dim,
-                              eps=1e-6,
-                              dtype=config.torch_dtype,
-                              has_weights=True)
-        self.aux_stream = torch.cuda.Stream()
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
-
-    def apply_qk_norm(self, q, k):
-
-        def q_l2norm():
-            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-                -1, self.q_size)
-
-        def k_l2norm():
-            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-                -1, self.kv_size)
-
-        q, k = maybe_execute_in_parallel(
-            q_l2norm,
-            k_l2norm,
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
-
-        return q, k
-
-    def apply_qk_norm_rope(self, qkv, position_ids):
-        torch.ops.trtllm.fused_qk_norm_rope(
-            qkv, self.num_heads, self.num_key_value_heads,
-            self.num_key_value_heads, self.head_dim,
-            self.q_norm.variance_epsilon, self.q_norm.weight,
-            self.k_norm.weight, self.pos_embd_params.rope.theta,
-            self.pos_embd_params.is_neox, position_ids.view(-1))
-        return qkv, None, None
-
-    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
-                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
-        # Qwen3 applies QK norm before RoPE.
-        if not self.fuse_qk_norm_rope:
-            q, k, v = self.split_qkv(q, k, v)
-            q, k = self.apply_qk_norm(q, k)
-            return super().apply_rope(q, k, v, position_ids)
-
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
-        qkv = q
-        return self.apply_qk_norm_rope(qkv, position_ids)
-
-
 class Qwen3DecoderLayer(DecoderLayer):
 
     def __init__(
@@ -123,7 +27,7 @@ def __init__(
         super().__init__()
         self.layer_idx = layer_idx
         config = model_config.pretrained_config
-        self.self_attn = Qwen3Attention(
+        self.self_attn = QKNormRoPEAttention(
            model_config,
            layer_idx=layer_idx,
        )

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 2 additions & 2 deletions
@@ -20,10 +20,10 @@
                                  RoutingMethodType, TRTLLMGenFusedMoE,
                                  create_moe)
 from ..modules.linear import TensorParallelMode
+from ..modules.qk_norm_attention import QKNormRoPEAttention
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
 from ..utils import AuxStreamType
-from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
 
@@ -166,7 +166,7 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
         super().__init__()
         self.model_config = model_config
         config = model_config.pretrained_config
-        self.self_attn = Qwen3Attention(
+        self.self_attn = QKNormRoPEAttention(
            model_config,
            layer_idx=layer_idx,
        )

tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+from typing import Optional
+
+import torch
+
+from ..attention_backend.interface import (PositionalEmbeddingParams,
+                                           PositionEmbeddingType, RopeParams)
+from ..model_config import ModelConfig
+from ..modules.attention import Attention
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
+from ..modules.rms_norm import RMSNorm
+
+
+class QKNormRoPEAttention(Attention):
+    """
+    Attention layer that applies QK normalization and RoPE to the query and key tensors,
+    as used by the Qwen3 models; overrides Attention.apply_rope to run QK norm before RoPE.
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        layer_idx: Optional[int] = None,
+        fuse_qk_norm_rope: bool = True,
+    ):
+        config = model_config.pretrained_config
+
+        if getattr(config, "rope_scaling", None) is not None:
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.from_string(
+                    config.rope_scaling["type"]),
+                rope=RopeParams.from_config(config),
+            )
+        else:
+            pos_embd_params = PositionalEmbeddingParams(
+                type=PositionEmbeddingType.rope_gpt_neox,
+                rope=RopeParams.from_config(config),
+            )
+
+        self.fuse_qk_norm_rope = fuse_qk_norm_rope
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=config.attention_bias,
+            pos_embd_params=pos_embd_params,
+            # If fuse_qk_norm_rope is true, do not apply fused RoPE in the attention op,
+            # and self.rotary_emb will be skipped in the overridden apply_rope.
+            rope_fusion=not self.fuse_qk_norm_rope,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            dense_bias=config.attention_bias,
+            config=model_config,
+        )
+
+        self.q_norm = RMSNorm(hidden_size=self.head_dim,
+                              eps=1e-6,
+                              dtype=config.torch_dtype,
+                              has_weights=True)
+        self.k_norm = RMSNorm(hidden_size=self.head_dim,
+                              eps=1e-6,
+                              dtype=config.torch_dtype,
+                              has_weights=True)
+        self.aux_stream = torch.cuda.Stream()
+        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
+
+    def apply_qk_norm(self, q, k):
+
+        def q_l2norm():
+            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
+                -1, self.q_size)
+
+        def k_l2norm():
+            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
+                -1, self.kv_size)
+
+        q, k = maybe_execute_in_parallel(
+            q_l2norm,
+            k_l2norm,
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
+        return q, k
+
+    def apply_qk_norm_rope(self, qkv, position_ids):
+        torch.ops.trtllm.fused_qk_norm_rope(
+            qkv, self.num_heads, self.num_key_value_heads,
+            self.num_key_value_heads, self.head_dim,
+            self.q_norm.variance_epsilon, self.q_norm.weight,
+            self.k_norm.weight, self.pos_embd_params.rope.theta,
+            self.pos_embd_params.is_neox, position_ids.view(-1))
+        return qkv, None, None
+
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        """
+        Called from the forward method of the base Attention class. Overridden here to
+        apply QK norm before RoPE, using either the fused kernel or the unfused path.
+        """
+        # Qwen3 applies QK norm before RoPE.
+        if not self.fuse_qk_norm_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+            return super().apply_rope(q, k, v, position_ids)
+
+        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
+        qkv = q
+        return self.apply_qk_norm_rope(qkv, position_ids)
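
For readers unfamiliar with the unfused path's math, here is a small self-contained PyTorch reference. It is an illustrative sketch only, not the TRT-LLM kernel or module: it assumes q and k are already reshaped to [num_tokens, num_heads, head_dim] and uses NeoX-style rotation, applying per-head RMSNorm first and RoPE second, matching the order in apply_rope above. The fused path applies the same math in a single op on the packed QKV tensor.

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the head dimension (last dim), computed in float32 for stability.
    x_f = x.float()
    x_f = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f * weight.float()).to(x.dtype)


def rope_neox(x: torch.Tensor, position_ids: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # NeoX-style RoPE: split the head dim into two halves and rotate the paired elements.
    # x: [num_tokens, num_heads, head_dim]; position_ids: [num_tokens].
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = position_ids[:, None].float() * inv_freq[None, :]  # [tokens, head_dim // 2]
    cos = angles.cos()[:, None, :]  # broadcast over heads
    sin = angles.sin()[:, None, :]
    x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).to(x.dtype)


def qk_norm_then_rope(q, k, q_weight, k_weight, position_ids):
    # Qwen3 order: QK RMSNorm first, RoPE second (mirrors apply_rope above).
    q = rope_neox(rms_norm(q, q_weight), position_ids)
    k = rope_neox(rms_norm(k, k_weight), position_ids)
    return q, k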
