
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
@@ -82,6 +83,8 @@ def __init__(
            model_config,
            layer_idx=layer_idx,
        )
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = self.mapping.enable_attention_dp

        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
        # and https://nvbugspro.nvidia.com/bug/5505402)
@@ -92,6 +95,7 @@ def __init__(
            intermediate_size=config.intermediate_size,
            bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
            dtype=config.torch_dtype,
+            overridden_tp_size=1 if self.enable_attention_dp else None,
            config=model_config,
            disable_deep_gemm=disable_deep_gemm,
        )
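
With attention DP enabled, `overridden_tp_size=1` builds the MLP with full, unsharded weights on each rank, so its output needs no cross-rank reduction afterwards. A minimal sketch of how such an override could resolve the effective sharding degree (illustrative only; the actual GatedMLP plumbing may differ):

from typing import Optional

def effective_tp_size(mapping_tp_size: int,
                      overridden_tp_size: Optional[int]) -> int:
    # Hypothetical resolution logic: an explicit override wins; otherwise
    # the module inherits the tensor-parallel size from the mapping.
    return overridden_tp_size if overridden_tp_size is not None else mapping_tp_size

assert effective_tp_size(4, None) == 4  # regular TP: weights sharded 4 ways
assert effective_tp_size(4, 1) == 1     # attention DP: full local weights
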
@@ -102,6 +106,8 @@ def __init__(
        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                eps=config.rms_norm_eps,
                                                dtype=config.torch_dtype)
+        self.disable_allreduce = (self.mapping.tp_size == 1
+                                  or self.enable_attention_dp)

    def forward(
        self,
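
The new `disable_allreduce` flag folds two cases into one condition that both call sites in forward reuse; restated standalone:

def should_disable_allreduce(tp_size: int, enable_attention_dp: bool) -> bool:
    # All-reduce is pointless with a single TP rank, and redundant under
    # attention DP, where each rank already holds full activations for its tokens.
    return tp_size == 1 or enable_attention_dp

assert should_disable_allreduce(1, False)      # single rank: skip
assert should_disable_allreduce(8, True)       # attention DP: skip
assert not should_disable_allreduce(8, False)  # plain TP: reduce
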
@@ -126,13 +132,22 @@ def forward(
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            mrope_config=mrope_config,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
            **kwargs,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states,
+            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
+            cutlass_min_latency_mode=False,
+        )

        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
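
For context, passing `AllReduceParams(enable_allreduce=False)` turns the layer's trailing collective into a no-op. A self-contained torch.distributed sketch of that gating pattern (the helper name and wiring are illustrative, not the TensorRT-LLM internals):

import torch
import torch.distributed as dist

def maybe_all_reduce(t: torch.Tensor, enable_allreduce: bool) -> torch.Tensor:
    # Sum partial results across tensor-parallel ranks only when enabled;
    # with tp_size == 1 or attention DP the tensor is returned untouched.
    if enable_allreduce and dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t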