Skip to content

Commit d28167f

Browse files
syuoni authored and codego7250 committed
[TRTLLM-9372][feat] Enable CuteDSL MoE with Large EP (NVIDIA#9592)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 28a2838 commit d28167f

File tree

19 files changed

+737
-359
lines changed

19 files changed

+737
-359
lines changed

cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <optional>
3939
#include <sstream>
4040
#include <string>
41+
#include <unordered_map>
4142
#ifndef _WIN32 // Linux
4243
#include <sys/sysinfo.h>
4344
#endif // not WIN32
@@ -432,6 +433,21 @@ inline int getMaxSharedMemoryPerBlockOptin()
432433
return nByteMaxSharedMemoryPerBlockOptin;
433434
}
434435

436+
template <typename T>
437+
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
438+
{
439+
static std::unordered_map<T, int> cache;
440+
auto it = cache.find(kernel);
441+
if (it != cache.end())
442+
{
443+
return it->second;
444+
}
445+
int numBlocks;
446+
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
447+
cache[kernel] = numBlocks;
448+
return numBlocks;
449+
}
450+
435451
template <typename T1, typename T2>
436452
inline size_t divUp(T1 const& a, T2 const& b)
437453
{

cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu

Lines changed: 113 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
6767
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
6868

6969
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
70-
asm volatile("griddepcontrol.wait;");
70+
cudaGridDependencySynchronize();
7171
#endif
7272

7373
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
@@ -110,7 +110,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
110110
}
111111

112112
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
113-
asm volatile("griddepcontrol.launch_dependents;");
113+
cudaTriggerProgrammaticLaunchCompletion();
114114
#endif
115115
}
116116

@@ -141,12 +141,12 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
141141
}
142142
#endif
143143

144+
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
144145
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
145-
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
146+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
147+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
146148
int32_t const threads = kThreadsPerBlock;
147149

148-
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
149-
150150
cudaLaunchConfig_t config;
151151
config.gridDim = blocks;
152152
config.blockDim = threads;
@@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
195195
int32_t const token_idx = blockIdx.x;
196196

197197
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
198-
asm volatile("griddepcontrol.wait;");
198+
cudaGridDependencySynchronize();
199199
#endif
200200

201201
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
@@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
232232
}
233233

234234
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
235-
asm volatile("griddepcontrol.launch_dependents;");
235+
cudaTriggerProgrammaticLaunchCompletion();
236236
#endif
237237
}
238238

@@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
277277
#endif
278278
#undef INSTANTIATE_MOE_UNPERMUTE
279279

280+
template <typename InputType, int32_t kThreadsPerBlock>
281+
__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
282+
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
283+
int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
284+
{
285+
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
286+
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
287+
288+
InputType rmem[kElemPerCopy];
289+
#pragma unroll
290+
for (int32_t j = 0; j < kElemPerCopy; j++)
291+
{
292+
rmem[j] = 0;
293+
}
294+
295+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
296+
cudaGridDependencySynchronize();
297+
#endif
298+
299+
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
300+
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
301+
{
302+
int32_t const tile_idx = permuted_idx / tile_size;
303+
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
304+
{
305+
continue;
306+
}
307+
int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
308+
int32_t const token_idx = expanded_idx / top_k;
309+
int32_t const topk_idx = expanded_idx % top_k;
310+
311+
bool is_first_in_topk = true;
312+
for (int32_t k = 0; k < topk_idx; k++)
313+
{
314+
if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
315+
{
316+
is_first_in_topk = false;
317+
break;
318+
}
319+
}
320+
if (!is_first_in_topk)
321+
{
322+
continue;
323+
}
324+
325+
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
326+
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
327+
{
328+
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
329+
}
330+
}
331+
332+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
333+
cudaTriggerProgrammaticLaunchCompletion();
334+
#endif
335+
}
336+
337+
template <typename InputType>
338+
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
339+
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
340+
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
341+
cudaStream_t stream)
342+
{
343+
int32_t constexpr kThreadsPerBlock = 256;
344+
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
345+
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
346+
347+
auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
348+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
349+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
350+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
351+
int32_t const threads = kThreadsPerBlock;
352+
353+
cudaLaunchConfig_t config;
354+
config.gridDim = blocks;
355+
config.blockDim = threads;
356+
config.dynamicSmemBytes = 0;
357+
config.stream = stream;
358+
cudaLaunchAttribute attrs[1];
359+
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
360+
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
361+
config.numAttrs = 1;
362+
config.attrs = attrs;
363+
cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
364+
permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
365+
}
366+
367+
#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \
368+
template void moeOutputMemset<InputType>(InputType * input, int32_t const* tile_idx_to_mn_limit, \
369+
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \
370+
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size, \
371+
int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
372+
373+
INSTANTIATE_MOE_OUTPUT_MEMSET(half);
374+
#ifdef ENABLE_BF16
375+
INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
376+
#endif
377+
#undef INSTANTIATE_MOE_OUTPUT_MEMSET
378+
280379
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
281380
int32_t kThreadsPerBlock>
282381
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
@@ -297,7 +396,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
297396
ActFn act{};
298397

299398
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
300-
asm volatile("griddepcontrol.wait;");
399+
cudaGridDependencySynchronize();
301400
#endif
302401

303402
float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
@@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
353452
}
354453

355454
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
356-
asm volatile("griddepcontrol.launch_dependents;");
455+
cudaTriggerProgrammaticLaunchCompletion();
357456
#endif
358457
}
359458

@@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
382481
}
383482
#endif
384483

385-
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
386-
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
387-
int32_t const threads = kThreadsPerBlock;
388-
389484
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
390485
float const* global_sf, SFType* output_sf,
391486
int32_t const* tile_idx_to_mn_limit,
@@ -424,6 +519,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
424519
};
425520
auto kernel = get_act_kernel(activation_params.activation_type);
426521

522+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
523+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
524+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
525+
int32_t const threads = kThreadsPerBlock;
526+
427527
cudaLaunchConfig_t config;
428528
config.gridDim = blocks;
429529
config.blockDim = threads;

cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t co
3232
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
3333
cudaStream_t stream);
3434

35+
template <typename InputType>
36+
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
37+
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
38+
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
39+
cudaStream_t stream);
40+
3541
template <typename InputType, typename OutputType, typename SFType>
3642
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
3743
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
15871587
int64_t num_padding_tokens = 0;
15881588
#endif
15891589

1590-
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1591-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
1592-
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
1593-
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
1594-
15951590
auto func = [&]()
15961591
{
15971592
#ifdef ENABLE_FP8
@@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
16371632
}
16381633
}();
16391634

1635+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1636+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
1637+
int32_t const blocks
1638+
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
1639+
int32_t const threads = EXPAND_THREADS_PER_BLOCK;
1640+
16401641
cudaLaunchConfig_t config;
16411642
config.gridDim = blocks;
16421643
config.blockDim = threads;
@@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
18911892
if (parallelism_config.ep_size > 1 && enable_alltoall)
18921893
{
18931894
// If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
1894-
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
1895-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
1896-
int64_t const blocks = smCount * 8;
1897-
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
1898-
config.gridDim = blocks;
1899-
config.blockDim = threads;
19001895
auto func = final_scales
19011896
? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
19021897
: &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;
1898+
1899+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
1900+
int32_t const maxBlocksPerSM
1901+
= tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
1902+
int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
1903+
int32_t const threads = FINALIZE_THREADS_PER_BLOCK;
1904+
1905+
config.gridDim = blocks;
1906+
config.blockDim = threads;
19031907
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
19041908
unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
19051909
expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
@@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
22352239
int64_t num_padding_tokens = 0;
22362240
#endif
22372241

2238-
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
2239-
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
2240-
int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
2241-
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
2242-
22432242
auto fn = [&]()
22442243
{
22452244
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
@@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
23022301
}
23032302
}();
23042303

2304+
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
2305+
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
2306+
int32_t const blocks
2307+
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
2308+
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
2309+
23052310
cudaLaunchConfig_t config;
23062311
config.gridDim = blocks;
23072312
config.blockDim = threads;

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,9 @@ void run(Data& data, void* stream)
647647
//
648648
// The upper bound is a strict requirement. The number of blocks should be determined by querying
649649
// the device properties, or conservatively low.
650-
static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();
650+
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
651+
// WAR: Reserve 8 SMs for overlapping kernels.
652+
int const numBlocksCoop = smCount - 8;
651653

652654
// Maximum number of tokens supported by the kernel using a cooperative launch.
653655
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;

0 commit comments

Comments (0)