
Commit c3266c7

improve based on comments

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>

1 parent 5b8a7ce
1 file changed (+17, -33)

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 17 additions & 33 deletions
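Both fake-op hunks below edit the shape-only meta function registered via register_fake for this custom op. For background, a minimal self-contained sketch of that PyTorch pattern, using an invented toy op (mylib::double_rows and its body are illustrative, not part of this file): the fake implementation runs under tracing and torch.compile, and must report output shapes and dtypes without doing any real work.

import torch
from torch.library import custom_op

# Toy custom op (illustrative): concatenates x with itself along dim 0.
@custom_op("mylib::double_rows", mutates_args=())
def double_rows(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x], dim=0)

# Fake (meta) implementation: no computation, just correctly-shaped
# empty outputs so the compiler can propagate shapes and dtypes.
@double_rows.register_fake
def _(x):
    return x.new_empty((x.shape[0] * 2, x.shape[1]))

This is why the substantive change below is all shape arithmetic: the fake must allocate outputs whose shapes are a valid (here, worst-case) bound on what the kernel will actually produce.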
@@ -1506,7 +1506,7 @@ def get_runner(
             self.num_experts,
             self.top_k,
         )
-        instance_key = (tile_tokens_dim, )
+        instance_key = (tile_tokens_dim, self.act_type)
         if instance_key not in FP8FP4BlockScaleMoERunner.runner_dict:
             FP8FP4BlockScaleMoERunner.runner_dict[
                 instance_key] = torch.classes.trtllm.FP8FP4BlockScaleMoERunner(
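The hunk above widens the runner cache key so that runners built for different activation types no longer collide. A minimal sketch of the failure mode being fixed, with stand-in names (runner_dict and get_runner here are simplified illustrations, not the file's exact code):

# Simplified memoization keyed by a tuple (stand-in for the real runner class).
runner_dict = {}

def get_runner(tile_tokens_dim, act_type):
    # Before this commit the key was (tile_tokens_dim,) alone, so two
    # configurations differing only in act_type shared one cached runner.
    instance_key = (tile_tokens_dim, act_type)
    if instance_key not in runner_dict:
        runner_dict[instance_key] = ("runner", tile_tokens_dim, act_type)
    return runner_dict[instance_key]

assert get_runner(8, 0) is get_runner(8, 0)      # same config: cache hit
assert get_runner(8, 0) is not get_runner(8, 1)  # act_type now disambiguates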
@@ -1668,30 +1668,6 @@ def fp8_fp4_block_scale_moe_runner(
     return kernel_runner(inputs, tactic=best_tactic)
 
 
-def fp8_fp4_block_scale_fake_output_without_finalize(
-        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
-        num_experts: int,
-        top_k: int,
-        routing_bias: Optional[torch.Tensor],
-):
-    num_tokens = hidden_states.shape[0]
-    hidden_size = hidden_states.shape[1]
-
-    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
-
-    expanded_row_count = num_tokens * top_k
-    max_padding_required = (tile_tokens_dim - 1) * num_experts
-    max_num_padded_tokens = fp4_utils.pad_up(
-        expanded_row_count + max_padding_required, tile_tokens_dim)
-    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
-    return [
-        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
-                                dtype=torch.bfloat16),
-        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
-        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
-    ]
-
-
 @fp8_fp4_block_scale_moe_runner.register_fake
 def _(
     routing_logits,
@@ -1716,17 +1692,25 @@ def _(
     do_finalize,
     act_type,
 ) -> List[torch.Tensor]:
+
+    num_tokens = hidden_states.shape[0]
+    hidden_size = hidden_states.shape[1]
+
     if do_finalize:
-        num_tokens = hidden_states.shape[0]
-        hidden_size = hidden_states.shape[1]
         return [
             hidden_states.new_empty((num_tokens, hidden_size),
                                     dtype=torch.bfloat16)
         ]
 
-    return fp8_fp4_block_scale_fake_output_without_finalize(
-        hidden_states,
-        num_experts,
-        top_k,
-        routing_bias,
-    )
+    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
+    expanded_row_count = num_tokens * top_k
+    max_padding_required = (tile_tokens_dim - 1) * num_experts
+    max_num_padded_tokens = fp4_utils.pad_up(
+        expanded_row_count + max_padding_required, tile_tokens_dim)
+    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
+    return [
+        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
+                                dtype=torch.bfloat16),
+        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
+        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
+    ]
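With the helper inlined, the non-finalized branch computes its worst-case allocation directly. A worked sketch of that shape math with illustrative sizes, assuming fp4_utils.pad_up(x, y) rounds x up to the next multiple of y (the real tile size comes from calculate_tile_tokens_dim):

# Assumed semantics for the padding helper: round x up to a multiple of y.
def pad_up(x, y):
    return ((x + y - 1) // y) * y

# Illustrative sizes, not taken from the file.
num_tokens, top_k, num_experts, tile_tokens_dim = 128, 8, 256, 8

expanded_row_count = num_tokens * top_k                     # 128 * 8 = 1024 routed rows
max_padding_required = (tile_tokens_dim - 1) * num_experts  # 7 * 256 = 1792
max_num_padded_tokens = pad_up(expanded_row_count + max_padding_required,
                               tile_tokens_dim)             # pad_up(2816, 8) = 2816

The bound holds because each expert can leave at most tile_tokens_dim - 1 empty slots in its last partial tile, so (tile_tokens_dim - 1) * num_experts padding rows on top of the num_tokens * top_k routed rows covers any routing outcome; the fake op then allocates (max_num_padded_tokens, hidden_size) for the unfinalized output.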
