@@ -30,7 +30,7 @@ using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Run
3030
3131std::vector<torch::Tensor> run_fp4_block_scale_moe_runner (torch::Tensor const & routing_logits,
3232 torch::optional<torch::Tensor> const & routing_bias, torch::Tensor const & hidden_states,
33- torch::Tensor const & hidden_states_scale, torch::Tensor const & gemm1_weights,
33+ torch::optional<torch:: Tensor> const & hidden_states_scale, torch::Tensor const & gemm1_weights,
3434 torch::Tensor const & gemm1_weights_scale, torch::Tensor const & gemm2_weights,
3535 torch::Tensor const & gemm2_weights_scale, torch::Tensor const & output1_scales_scalar,
3636 torch::Tensor const & output1_scales_gate_scalar, torch::Tensor const & output2_scales_scalar,
@@ -39,6 +39,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
3939 int64_t const local_num_experts, std::optional<double > const routed_scaling_factor, int64_t const tile_tokens_dim,
4040 int64_t const routing_method_type, bool const do_finalize, MoeRunnerType& moe_runner, int64_t const moeConfigIndex)
4141{
42+ bool const isFp8Fp4 = !hidden_states_scale.has_value ();
4243 TORCH_CHECK (tensorrt_llm::common::isSM100Family (), " Only SM100f is supported by FP4 block scale MOE" );
4344 TORCH_CHECK (tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64 ,
4445 " tile_tokens_dim must be 8, 16, 32, 64" );
@@ -102,15 +103,22 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
102103 args.routing_logits = routing_logits.data_ptr ();
103104 args.routing_bias = routing_bias.has_value () ? routing_bias.value ().data_ptr () : nullptr ;
104105 args.hidden_states = hidden_states.data_ptr ();
105- args.hidden_states_scale = hidden_states_scale.data_ptr ();
106+ args.hidden_states_scale = hidden_states_scale.has_value () ? hidden_states_scale. value (). data_ptr () : nullptr ;
106107 args.gemm1_weights = gemm1_weights.data_ptr ();
107108 args.gemm1_weights_scale = gemm1_weights_scale.data_ptr ();
108109 args.gemm2_weights = gemm2_weights.data_ptr ();
109110 args.gemm2_weights_scale = gemm2_weights_scale.data_ptr ();
110111 args.num_tokens = hidden_states.sizes ()[0 ];
111112 args.num_experts = num_experts;
112113 // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1 into 1 byte.
113- args.hidden_size = hidden_states.sizes ()[1 ] * 2 ;
114+ if (isFp8Fp4)
115+ {
116+ args.hidden_size = hidden_states.sizes ()[1 ];
117+ }
118+ else
119+ {
120+ args.hidden_size = hidden_states.sizes ()[1 ] * 2 ;
121+ }
114122 args.top_k = top_k;
115123 args.n_group = n_group.value_or (0 );
116124 args.topk_group = topk_group.value_or (0 );
@@ -180,22 +188,25 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
180188 // FC13 (gemm1) + FC2 (gemm2)
181189 //
182190
183- TORCH_CHECK (hidden_states.scalar_type () == FLOAT4_E2M1X2, " hidden_states must be byte." );
184- TORCH_CHECK (hidden_states_scale.scalar_type () == at::ScalarType::Float8_e4m3fn, " hidden_states_scale must be fp8." );
185-
186- TORCH_CHECK (hidden_states_scale.dim () == 1 , " hidden_states_scale must be 1D." );
187- TORCH_CHECK (hidden_states_scale.sizes ()[0 ]
188- == tensorrt_llm::computeLinearLayoutSFSize (args.num_tokens , args.hidden_size / 16 ),
189- " hidden_states_scale has incorrect size" );
190-
191+ if (!isFp8Fp4)
192+ {
193+ TORCH_CHECK (hidden_states.scalar_type () == FLOAT4_E2M1X2, " hidden_states must be byte." );
194+ TORCH_CHECK (hidden_states_scale.value ().scalar_type () == at::ScalarType::Float8_e4m3fn,
195+ " hidden_states_scale must be fp8." );
196+
197+ TORCH_CHECK (hidden_states_scale.value ().dim () == 1 , " hidden_states_scale must be 1D." );
198+ TORCH_CHECK (hidden_states_scale.value ().sizes ()[0 ]
199+ == tensorrt_llm::computeLinearLayoutSFSize (args.num_tokens , args.hidden_size / 16 ),
200+ " hidden_states_scale has incorrect size" );
201+ }
191202 TORCH_CHECK (gemm1_weights.scalar_type () == FLOAT4_E2M1X2, " gemm1_weights must be byte." );
192203
193204 TORCH_CHECK (gemm1_weights.dim () == 3 , " gemm1_weights must be 3D." );
194205 TORCH_CHECK (gemm1_weights.sizes ()[1 ] % 2 == 0 , " the second dimension of weights must be even." );
195206 TORCH_CHECK (intermediate_size == gemm1_weights.sizes ()[1 ] / 2 , " intermediate_size has incorrect dim 1." );
196207 // This check passes even though the actual shape of the weights[2] and hidden_states[1] is
197208 // 2 times larger due to the fact that 2 e2m1 are packed into 1 byte.
198- TORCH_CHECK (gemm1_weights.sizes ()[2 ] == hidden_states.sizes ()[1 ],
209+ TORCH_CHECK (gemm1_weights.sizes ()[2 ] * 2 == hidden_states.sizes ()[1 ],
199210 " the third dimension of weights must be equal to hidden_size." );
200211
201212 TORCH_CHECK (gemm1_weights_scale.scalar_type () == at::ScalarType::Float8_e4m3fn, " gemm1_weights_scale must be fp8." );
@@ -204,7 +215,16 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
204215 TORCH_CHECK (gemm1_weights_scale.sizes ()[0 ] == local_num_experts, " gemm1_weights_scale has incorrect dim 0." );
205216 TORCH_CHECK (intermediate_size % 16 == 0 , " the second dimension of weights must be a multiple of 16." );
206217 TORCH_CHECK (gemm1_weights_scale.sizes ()[1 ] == 2 * intermediate_size, " gemm1_weights_scale has incorrect dim 1." );
207- TORCH_CHECK (gemm1_weights_scale.sizes ()[2 ] == args.hidden_size / 16 , " gemm1_weights_scale has incorrect dim 2." );
218+ if (isFp8Fp4)
219+ {
220+ TORCH_CHECK (
221+ gemm1_weights_scale.sizes ()[2 ] == args.hidden_size / 32 , " gemm1_weights_scale has incorrect dim 2." );
222+ }
223+ else
224+ {
225+ TORCH_CHECK (
226+ gemm1_weights_scale.sizes ()[2 ] == args.hidden_size / 16 , " gemm1_weights_scale has incorrect dim 2." );
227+ }
208228
209229 TORCH_CHECK (gemm2_weights.scalar_type () == FLOAT4_E2M1X2, " gemm2_weights must be byte." );
210230
@@ -218,7 +238,16 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::Tensor const& r
218238 TORCH_CHECK (gemm2_weights_scale.dim () == 3 , " gemm2_weights_scale must be 3D." );
219239 TORCH_CHECK (gemm2_weights_scale.sizes ()[0 ] == local_num_experts, " gemm2_weights_scale has incorrect dim 0." );
220240 TORCH_CHECK (gemm2_weights_scale.sizes ()[1 ] == args.hidden_size , " gemm2_weights_scale has incorrect dim 1." );
221- TORCH_CHECK (gemm2_weights_scale.sizes ()[2 ] == intermediate_size / 16 , " gemm2_weights_scale has incorrect dim 2." );
241+ if (isFp8Fp4)
242+ {
243+ TORCH_CHECK (
244+ gemm2_weights_scale.sizes ()[2 ] == intermediate_size / 32 , " gemm2_weights_scale has incorrect dim 2." );
245+ }
246+ else
247+ {
248+ TORCH_CHECK (
249+ gemm2_weights_scale.sizes ()[2 ] == intermediate_size / 16 , " gemm2_weights_scale has incorrect dim 2." );
250+ }
222251
223252 TORCH_CHECK (output1_scales_scalar.scalar_type () == at::ScalarType::Float, " output1_scales_scalar must be float." );
224253 TORCH_CHECK (output1_scales_scalar.dim () == 1 , " output1_scales_scalar must be 1D." );
@@ -343,6 +372,65 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
343372 int64_t mTileTokensDim ;
344373};
345374
375+ // Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
376+ // use with the torch workflow autotuner class.
377+ class FP8FP4BlockScaleMoeRunner : public torch ::CustomClassHolder
378+ {
379+ public:
380+ explicit FP8FP4BlockScaleMoeRunner (int64_t tileTokensDim, int64_t actType)
381+ : mTileTokensDim(tileTokensDim)
382+ {
383+ mRunner = std::make_unique<RunnerType>(mDtypeAct , mDtypeWeights , mUseDeepSeekFp8 , mTileTokensDim ,
384+ static_cast <tensorrt_llm::kernels::ActType>(actType));
385+ }
386+
387+ [[nodiscard]] std::vector<torch::Tensor> run (torch::Tensor const & routing_logits,
388+ torch::optional<torch::Tensor> const & routing_bias, torch::Tensor const & hidden_states,
389+ torch::Tensor const & gemm1_weights, torch::Tensor const & gemm1_weights_scale,
390+ torch::Tensor const & gemm2_weights, torch::Tensor const & gemm2_weights_scale,
391+ torch::Tensor const & output1_scales_scalar, torch::Tensor const & output1_scales_gate_scalar,
392+ torch::Tensor const & output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
393+ std::optional<int64_t > const n_group, std::optional<int64_t > const topk_group, int64_t const intermediate_size,
394+ int64_t const local_expert_offset, int64_t const local_num_experts,
395+ std::optional<double > const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
396+ int64_t moeConfigIndex)
397+ {
398+
399+ // Autotuner has requested a default or 'fallback' config index
400+ if (moeConfigIndex == -1 )
401+ {
402+ auto const num_tokens = hidden_states.sizes ()[0 ];
403+
404+ // 2x FP4 per byte element
405+ auto const hidden_size = 2 * hidden_states.sizes ()[1 ];
406+
407+ moeConfigIndex = mRunner ->getDefaultValidConfigIndex (
408+ top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
409+ }
410+
411+ return run_fp4_block_scale_moe_runner (routing_logits, routing_bias, hidden_states,
412+ std::nullopt /* hidden_states_scale*/ , gemm1_weights, gemm1_weights_scale, gemm2_weights,
413+ gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
414+ top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
415+ routed_scaling_factor, mTileTokensDim , routing_method_type, do_finalize, *mRunner , moeConfigIndex);
416+ }
417+
418+ [[nodiscard]] std::vector<int64_t > getValidConfigs (
419+ int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
420+ {
421+ return mRunner ->getValidConfigIndices (topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
422+ }
423+
424+ private:
425+ using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
426+
427+ std::unique_ptr<RunnerType> mRunner ;
428+ btg::Dtype mDtypeAct {btg::Dtype::E4m3};
429+ btg::Dtype mDtypeWeights {btg::Dtype::E2m1};
430+ bool mUseDeepSeekFp8 {false };
431+ int64_t mTileTokensDim ;
432+ };
433+
346434torch::Tensor shuffleMatrix (torch::Tensor matrix, torch::Tensor permuteIndices)
347435{
348436 return torch::index_select (matrix, 0 , permuteIndices);
@@ -356,6 +444,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
356444 .def (torch::init<int64_t >())
357445 .def (" get_valid_configs" , &torch_ext::FP4BlockScaleMoeRunner::getValidConfigs)
358446 .def (" run_moe" , &torch_ext::FP4BlockScaleMoeRunner::run);
447+ m.class_ <torch_ext::FP8FP4BlockScaleMoeRunner>(" FP8FP4BlockScaleMoERunner" )
448+ .def (torch::init<int64_t , int64_t >())
449+ .def (" get_valid_configs" , &torch_ext::FP8FP4BlockScaleMoeRunner::getValidConfigs)
450+ .def (" run_moe" , &torch_ext::FP8FP4BlockScaleMoeRunner::run);
359451}
360452
361453// Accepts both CPU and CUDA tensors
0 commit comments