
Commit e2ccf10

expose the kernel to python
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent 5b7e15a commit e2ccf10

File tree: 3 files changed (+661, -23 lines)


cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 106 additions & 14 deletions
@@ -30,7 +30,7 @@ using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Run

std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& routing_logits,
    torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
-    torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
+    torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
    torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
    torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
    torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
@@ -39,6 +39,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
    int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
    int64_t const routing_method_type, bool const do_finalize, MoeRunnerType& moe_runner, int64_t const moeConfigIndex)
{
+    bool const isFp8Fp4 = !hidden_states_scale.has_value();
    TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
    TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64,
        "tile_tokens_dim must be 8, 16, 32, 64");
@@ -102,15 +103,22 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
    args.routing_logits = routing_logits.data_ptr();
    args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
    args.hidden_states = hidden_states.data_ptr();
-    args.hidden_states_scale = hidden_states_scale.data_ptr();
+    args.hidden_states_scale = hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr;
    args.gemm1_weights = gemm1_weights.data_ptr();
    args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
    args.gemm2_weights = gemm2_weights.data_ptr();
    args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
    args.num_tokens = hidden_states.sizes()[0];
    args.num_experts = num_experts;
    // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1 into 1 byte.
-    args.hidden_size = hidden_states.sizes()[1] * 2;
+    if (isFp8Fp4)
+    {
+        args.hidden_size = hidden_states.sizes()[1];
+    }
+    else
+    {
+        args.hidden_size = hidden_states.sizes()[1] * 2;
+    }
    args.top_k = top_k;
    args.n_group = n_group.value_or(0);
    args.topk_group = topk_group.value_or(0);
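
Note on the hunk above: an absent hidden_states_scale marks the FP8-activation path, where each activation occupies one byte, so args.hidden_size is simply the last dimension of hidden_states; with packed FP4 activations (two e2m1 values per byte) it is twice that. A minimal Python-side sketch of the same rule, purely illustrative (the helper name is not part of the extension):

# Hypothetical helper mirroring the C++ branch above.
def logical_hidden_size(hidden_states, hidden_states_scale):
    if hidden_states_scale is None:
        return hidden_states.shape[1]      # FP8 (e4m3): one value per byte
    return hidden_states.shape[1] * 2      # FP4 (e2m1): two values packed per byte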
@@ -180,22 +188,25 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
    // FC13 (gemm1) + FC2 (gemm2)
    //

-    TORCH_CHECK(hidden_states.scalar_type() == FLOAT4_E2M1X2, "hidden_states must be byte.");
-    TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states_scale must be fp8.");
-
-    TORCH_CHECK(hidden_states_scale.dim() == 1, "hidden_states_scale must be 1D.");
-    TORCH_CHECK(hidden_states_scale.sizes()[0]
-            == tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / 16),
-        "hidden_states_scale has incorrect size");
-
+    if (!isFp8Fp4)
+    {
+        TORCH_CHECK(hidden_states.scalar_type() == FLOAT4_E2M1X2, "hidden_states must be byte.");
+        TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
+            "hidden_states_scale must be fp8.");
+
+        TORCH_CHECK(hidden_states_scale.value().dim() == 1, "hidden_states_scale must be 1D.");
+        TORCH_CHECK(hidden_states_scale.value().sizes()[0]
+                == tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / 16),
+            "hidden_states_scale has incorrect size");
+    }
    TORCH_CHECK(gemm1_weights.scalar_type() == FLOAT4_E2M1X2, "gemm1_weights must be byte.");

    TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
    TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
    TORCH_CHECK(intermediate_size == gemm1_weights.sizes()[1] / 2, "intermediate_size has incorrect dim 1.");
    // This check passes even though the actual shape of the weights[2] and hidden_states[1] is
    // 2 times larger due to the fact that 2 e2m1 are packed into 1 byte.
-    TORCH_CHECK(gemm1_weights.sizes()[2] == hidden_states.sizes()[1],
+    TORCH_CHECK(gemm1_weights.sizes()[2] * 2 == hidden_states.sizes()[1],
        "the third dimension of weights must be equal to hidden_size.");

    TORCH_CHECK(gemm1_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm1_weights_scale must be fp8.");
@@ -204,7 +215,16 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
    TORCH_CHECK(gemm1_weights_scale.sizes()[0] == local_num_experts, "gemm1_weights_scale has incorrect dim 0.");
    TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
    TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
-    TORCH_CHECK(gemm1_weights_scale.sizes()[2] == args.hidden_size / 16, "gemm1_weights_scale has incorrect dim 2.");
+    if (isFp8Fp4)
+    {
+        TORCH_CHECK(
+            gemm1_weights_scale.sizes()[2] == args.hidden_size / 32, "gemm1_weights_scale has incorrect dim 2.");
+    }
+    else
+    {
+        TORCH_CHECK(
+            gemm1_weights_scale.sizes()[2] == args.hidden_size / 16, "gemm1_weights_scale has incorrect dim 2.");
+    }

    TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");

@@ -218,7 +238,16 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
    TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
    TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
    TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
-    TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / 16, "gemm2_weights_scale has incorrect dim 2.");
+    if (isFp8Fp4)
+    {
+        TORCH_CHECK(
+            gemm2_weights_scale.sizes()[2] == intermediate_size / 32, "gemm2_weights_scale has incorrect dim 2.");
+    }
+    else
+    {
+        TORCH_CHECK(
+            gemm2_weights_scale.sizes()[2] == intermediate_size / 16, "gemm2_weights_scale has incorrect dim 2.");
+    }

    TORCH_CHECK(output1_scales_scalar.scalar_type() == at::ScalarType::Float, "output1_scales_scalar must be float.");
    TORCH_CHECK(output1_scales_scalar.dim() == 1, "output1_scales_scalar must be 1D.");
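
Taken together, the conditional checks above only change the block size used for the weight scales: the last scale dimension is divided by 32 on the FP8-activation path and by 16 on the FP4-activation path, while the leading dimensions stay the same. A short sketch of the expected scale shapes, with made-up sizes for illustration:

# Illustrative sizes only; the block sizes follow the TORCH_CHECKs above.
local_num_experts, hidden_size, intermediate_size = 8, 7168, 2048

def expected_scale_shapes(is_fp8_fp4):
    block = 32 if is_fp8_fp4 else 16
    gemm1_scale = (local_num_experts, 2 * intermediate_size, hidden_size // block)
    gemm2_scale = (local_num_experts, hidden_size, intermediate_size // block)
    return gemm1_scale, gemm2_scale

print(expected_scale_shapes(True))   # FP8 activations x FP4 weights
print(expected_scale_shapes(False))  # FP4 activations x FP4 weights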
@@ -343,6 +372,65 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
    int64_t mTileTokensDim;
};

+// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
+// use with the torch workflow autotuner class.
+class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
+{
+public:
+    explicit FP8FP4BlockScaleMoeRunner(int64_t tileTokensDim, int64_t actType)
+        : mTileTokensDim(tileTokensDim)
+    {
+        mRunner = std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, mTileTokensDim,
+            static_cast<tensorrt_llm::kernels::ActType>(actType));
+    }
+
+    [[nodiscard]] std::vector<torch::Tensor> run(torch::Tensor const& routing_logits,
+        torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
+        torch::Tensor const& gemm1_weights, torch::Tensor const& gemm1_weights_scale,
+        torch::Tensor const& gemm2_weights, torch::Tensor const& gemm2_weights_scale,
+        torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
+        torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
+        std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
+        int64_t const local_expert_offset, int64_t const local_num_experts,
+        std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
+        int64_t moeConfigIndex)
+    {
+
+        // Autotuner has requested a default or 'fallback' config index
+        if (moeConfigIndex == -1)
+        {
+            auto const num_tokens = hidden_states.sizes()[0];
+
+            // 2x FP4 per byte element
+            auto const hidden_size = 2 * hidden_states.sizes()[1];
+
+            moeConfigIndex = mRunner->getDefaultValidConfigIndex(
+                top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+        }
+
+        return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
+            std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
+            gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
+            top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
+            routed_scaling_factor, mTileTokensDim, routing_method_type, do_finalize, *mRunner, moeConfigIndex);
+    }
+
+    [[nodiscard]] std::vector<int64_t> getValidConfigs(
+        int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
+    {
+        return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
+    }
+
+private:
+    using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
+
+    std::unique_ptr<RunnerType> mRunner;
+    btg::Dtype mDtypeAct{btg::Dtype::E4m3};
+    btg::Dtype mDtypeWeights{btg::Dtype::E2m1};
+    bool mUseDeepSeekFp8{false};
+    int64_t mTileTokensDim;
+};
+
torch::Tensor shuffleMatrix(torch::Tensor matrix, torch::Tensor permuteIndices)
{
    return torch::index_select(matrix, 0, permuteIndices);
@@ -356,6 +444,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
        .def(torch::init<int64_t>())
        .def("get_valid_configs", &torch_ext::FP4BlockScaleMoeRunner::getValidConfigs)
        .def("run_moe", &torch_ext::FP4BlockScaleMoeRunner::run);
+    m.class_<torch_ext::FP8FP4BlockScaleMoeRunner>("FP8FP4BlockScaleMoERunner")
+        .def(torch::init<int64_t, int64_t>())
+        .def("get_valid_configs", &torch_ext::FP8FP4BlockScaleMoeRunner::getValidConfigs)
+        .def("run_moe", &torch_ext::FP8FP4BlockScaleMoeRunner::run);
}

// Accepts both CPU and CUDA tensors
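
With this registration the new runner becomes reachable from Python as a torch custom class in the trtllm namespace, which is what the commit title refers to. A minimal usage sketch, assuming the compiled TensorRT-LLM torch extension is already loaded in the process; all numeric values below are placeholders:

import torch

# Constructor arguments match torch::init<int64_t, int64_t>:
# (tile_tokens_dim, act_type); act_type is cast to tensorrt_llm::kernels::ActType.
runner = torch.classes.trtllm.FP8FP4BlockScaleMoERunner(8, 0)

# Query the tactic space for the autotuner (placeholder problem sizes).
top_k, hidden_size, intermediate_size = 8, 7168, 2048
local_num_experts, num_tokens = 8, 16
configs = runner.get_valid_configs(top_k, hidden_size, intermediate_size,
                                   local_num_experts, num_tokens)

# run_moe takes the tensors and scalars in the order of the C++ run() signature
# above, ending with a config index (or -1 to let the runner pick a default):
# runner.run_moe(routing_logits, routing_bias, hidden_states,
#                gemm1_weights, gemm1_weights_scale, gemm2_weights,
#                gemm2_weights_scale, output1_scales_scalar,
#                output1_scales_gate_scalar, output2_scales_scalar,
#                num_experts, top_k, n_group, topk_group, intermediate_size,
#                local_expert_offset, local_num_experts, routed_scaling_factor,
#                routing_method_type, do_finalize, -1)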
