@@ -1506,7 +1506,7 @@ def get_runner(
             self.num_experts,
             self.top_k,
         )
-        instance_key = (tile_tokens_dim, )
+        instance_key = (tile_tokens_dim, self.act_type)
         if instance_key not in FP8FP4BlockScaleMoERunner.runner_dict:
             FP8FP4BlockScaleMoERunner.runner_dict[
                 instance_key] = torch.classes.trtllm.FP8FP4BlockScaleMoERunner(
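The one-line change above widens the key of the lazily built runner cache so the activation type participates in the lookup. A minimal sketch of that keyed-cache pattern, with illustrative names (`get_cached_runner`, `make_runner`) that are not part of the TensorRT-LLM API: every argument that influences how the runner is constructed must appear in the key, otherwise two call sites that differ only in `act_type` would silently share one cached instance.

from typing import Callable, Dict, Tuple

# Hypothetical cache; in the real code this role is played by
# FP8FP4BlockScaleMoERunner.runner_dict.
_runner_cache: Dict[Tuple[int, int], object] = {}


def get_cached_runner(tile_tokens_dim: int, act_type: int,
                      make_runner: Callable[[int, int], object]) -> object:
    # The key covers every constructor argument that can vary between calls;
    # dropping act_type is exactly the collision the diff above fixes.
    key = (tile_tokens_dim, act_type)
    if key not in _runner_cache:
        _runner_cache[key] = make_runner(tile_tokens_dim, act_type)
    return _runner_cache[key]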
@@ -1668,30 +1668,6 @@ def fp8_fp4_block_scale_moe_runner(
     return kernel_runner(inputs, tactic=best_tactic)
 
 
-def fp8_fp4_block_scale_fake_output_without_finalize(
-        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
-        num_experts: int,
-        top_k: int,
-        routing_bias: Optional[torch.Tensor],
-):
-    num_tokens = hidden_states.shape[0]
-    hidden_size = hidden_states.shape[1]
-
-    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
-
-    expanded_row_count = num_tokens * top_k
-    max_padding_required = (tile_tokens_dim - 1) * num_experts
-    max_num_padded_tokens = fp4_utils.pad_up(
-        expanded_row_count + max_padding_required, tile_tokens_dim)
-    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
-    return [
-        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
-                                dtype=torch.bfloat16),
-        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
-        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
-    ]
-
-
 @fp8_fp4_block_scale_moe_runner.register_fake
 def _(
     routing_logits,
@@ -1716,17 +1692,25 @@ def _(
     do_finalize,
     act_type,
 ) -> List[torch.Tensor]:
+
+    num_tokens = hidden_states.shape[0]
+    hidden_size = hidden_states.shape[1]
+
     if do_finalize:
-        num_tokens = hidden_states.shape[0]
-        hidden_size = hidden_states.shape[1]
         return [
             hidden_states.new_empty((num_tokens, hidden_size),
                                     dtype=torch.bfloat16)
         ]
 
-    return fp8_fp4_block_scale_fake_output_without_finalize(
-        hidden_states,
-        num_experts,
-        top_k,
-        routing_bias,
-    )
+    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
+    expanded_row_count = num_tokens * top_k
+    max_padding_required = (tile_tokens_dim - 1) * num_experts
+    max_num_padded_tokens = fp4_utils.pad_up(
+        expanded_row_count + max_padding_required, tile_tokens_dim)
+    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
+    return [
+        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
+                                dtype=torch.bfloat16),
+        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
+        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
+    ]
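The last two hunks inline the no-finalize fake-output computation directly into the op's `register_fake` implementation instead of delegating to the removed helper. The sketch below shows the same pattern on a toy op, assuming the public `torch.library.custom_op` API; the op name `example::moe_permute`, the stand-in kernel body, and the local `pad_up` helper are illustrative, not the TensorRT-LLM implementation. The fake function only needs to produce empty tensors with the worst-case padded shape so the op can be traced without running the kernel.

import torch


def pad_up(value: int, multiple: int) -> int:
    # Round value up to the next multiple (same role as fp4_utils.pad_up above).
    return ((value + multiple - 1) // multiple) * multiple


@torch.library.custom_op("example::moe_permute", mutates_args=())
def moe_permute(hidden_states: torch.Tensor, num_experts: int, top_k: int,
                tile_tokens_dim: int) -> torch.Tensor:
    # Stand-in "real" implementation: just allocate the padded buffer.
    num_tokens, hidden_size = hidden_states.shape
    max_padded = pad_up(
        num_tokens * top_k + (tile_tokens_dim - 1) * num_experts,
        tile_tokens_dim)
    return hidden_states.new_zeros((max_padded, hidden_size),
                                   dtype=torch.bfloat16)


@moe_permute.register_fake
def _(hidden_states, num_experts, top_k, tile_tokens_dim):
    num_tokens, hidden_size = hidden_states.shape
    # Worst case: each of the num_experts groups is one token short of a full
    # tile, so up to (tile_tokens_dim - 1) * num_experts padding rows appear.
    expanded_row_count = num_tokens * top_k
    max_padded = pad_up(
        expanded_row_count + (tile_tokens_dim - 1) * num_experts,
        tile_tokens_dim)
    return hidden_states.new_empty((max_padded, hidden_size),
                                   dtype=torch.bfloat16)

For example, with num_tokens = 3, top_k = 2, num_experts = 8, and tile_tokens_dim = 8, the expanded row count is 6, the worst-case padding is 7 * 8 = 56, and pad_up(6 + 56, 8) = 64, so the fake output is a (64, hidden_size) bfloat16 tensor.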