@@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
     int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 
     int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
@@ -110,7 +110,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
     }
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
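The two hunks above replace the hand-written `griddepcontrol.wait` / `griddepcontrol.launch_dependents` PTX with the CUDA runtime device intrinsics that emit the same instructions. A minimal sketch of the resulting programmatic dependent launch (PDL) pattern follows; the kernel and buffers are illustrative placeholders, not code from this change, and the intrinsics need CUDA 11.8+ compiled for sm_90 or newer.

```cuda
#include <cuda_runtime.h>

// Illustrative consumer kernel only: shows where the two PDL intrinsics sit
// relative to the work that depends on the previous grid's output.
__global__ void consumerKernel(float const* produced, float* out, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the grid this one programmatically depends on has finished
    // writing the memory we are about to read.
    cudaGridDependencySynchronize();
#endif

    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
    {
        out[idx] = produced[idx] * 2.0f;
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Signal that dependent grids may start launching; any work still left in
    // this kernel overlaps with their prologue.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
```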
@@ -141,12 +141,12 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
     }
 #endif
 
+    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
     static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-    int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
+    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
+    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
     int32_t const threads = kThreadsPerBlock;
 
-    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
-
     cudaLaunchConfig_t config;
     config.gridDim = blocks;
     config.blockDim = threads;
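The launcher above also switches from a fixed `smCount * 8` grid to one sized from the kernel's actual occupancy. `getMaxActiveBlocksPerSM` is a TensorRT-LLM helper whose implementation is not part of this diff; assuming it wraps the standard occupancy query, it would look roughly like the sketch below. The grid is still capped at `max_num_permuted_tokens`, so small workloads do not launch idle blocks.

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Hedged sketch only: a plausible shape for a helper like getMaxActiveBlocksPerSM,
// built on the stock CUDA occupancy API. The real TensorRT-LLM implementation is
// not shown in this diff and may differ.
template <typename KernelFn>
int32_t maxActiveBlocksPerSmSketch(KernelFn kernel, int32_t blockSize, size_t dynamicSmemBytes)
{
    int numBlocks = 0;
    // Reports how many blocks of `kernel` can be resident on one SM, given its
    // register/shared-memory usage and the requested block size.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSmemBytes);
    return static_cast<int32_t>(numBlocks);
}
```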
@@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
     int32_t const token_idx = blockIdx.x;
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 
     auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
@@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
     }
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
@@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
 #endif
 #undef INSTANTIATE_MOE_UNPERMUTE
 
+template <typename InputType, int32_t kThreadsPerBlock>
+__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
+    int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
+    int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
+{
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
+
+    InputType rmem[kElemPerCopy];
+#pragma unroll
+    for (int32_t j = 0; j < kElemPerCopy; j++)
+    {
+        rmem[j] = 0;
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+#endif
+
+    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
+    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
+    {
+        int32_t const tile_idx = permuted_idx / tile_size;
+        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
+        {
+            continue;
+        }
+        int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
+        int32_t const token_idx = expanded_idx / top_k;
+        int32_t const topk_idx = expanded_idx % top_k;
+
+        bool is_first_in_topk = true;
+        for (int32_t k = 0; k < topk_idx; k++)
+        {
+            if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
+            {
+                is_first_in_topk = false;
+                break;
+            }
+        }
+        if (!is_first_in_topk)
+        {
+            continue;
+        }
+
+        auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
+        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
+        {
+            dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
+        }
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+template <typename InputType>
+void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
+    int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
+    int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
+    cudaStream_t stream)
+{
+    int32_t constexpr kThreadsPerBlock = 256;
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
+
+    auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
+    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
+    int32_t const threads = kThreadsPerBlock;
+
+    cudaLaunchConfig_t config;
+    config.gridDim = blocks;
+    config.blockDim = threads;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
+        permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
+}
+
+#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType)                                                                       \
+    template void moeOutputMemset<InputType>(InputType* input, int32_t const* tile_idx_to_mn_limit,                    \
+        int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,                      \
+        int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,        \
+        int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
+
+INSTANTIATE_MOE_OUTPUT_MEMSET(half);
+#ifdef ENABLE_BF16
+INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
+#endif
+#undef INSTANTIATE_MOE_OUTPUT_MEMSET
+
 template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
     int32_t kThreadsPerBlock>
 __global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
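The new `moeOutputMemsetKernel` zeroes each token's output row exactly once: the permuted slot that corresponds to the token's first valid top-k entry performs the vectorized store, and every other slot skips it. That lets the later unpermute/finalize accumulation start from zeroed rows without a separate `cudaMemsetAsync`, and the PDL attributes in `moeOutputMemset` keep the zeroing inside the same launch-dependency chain. A call-site sketch follows; every buffer name below is a placeholder, not taken from this diff.

```cuda
// Illustrative only: launch the row-zeroing pass on the same stream as the
// surrounding MoE kernels so PDL can overlap it with the previous grid.
moeOutputMemset<half>(final_output,      // [num_tokens, hidden_size] rows to zero
    tile_idx_to_mn_limit,                // per-tile limit of valid permuted rows
    expanded_idx_to_permuted_idx,        // token * top_k + k -> permuted slot (negative if unused)
    permuted_idx_to_expanded_idx,        // permuted slot -> expanded (token * top_k) index
    num_non_exiting_tiles,               // device-side count of active tiles
    max_num_permuted_tokens, hidden_size, top_k, tile_size, stream);
```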
@@ -297,7 +396,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
     ActFn act{};
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 
     float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
@@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
     }
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
@@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
     }
 #endif
 
-    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-    int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
-    int32_t const threads = kThreadsPerBlock;
-
     auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
         float const* global_sf, SFType* output_sf,
         int32_t const* tile_idx_to_mn_limit,
@@ -424,6 +519,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
     };
     auto kernel = get_act_kernel(activation_params.activation_type);
 
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
+    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
+    int32_t const threads = kThreadsPerBlock;
+
     cudaLaunchConfig_t config;
     config.gridDim = blocks;
     config.blockDim = threads;
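In `moeActivation`, the grid-sizing block moves below `get_act_kernel` because the occupancy query needs the kernel pointer that will actually be launched: each activation instantiation can have different register and shared-memory usage, and therefore a different blocks-per-SM limit. A condensed paraphrase of the reordered flow (argument lists elided, names as in the diff):

```cuda
// Paraphrase of the reordering above, not new code: pick the kernel first,
// then size the grid from that kernel's measured occupancy.
auto kernel = get_act_kernel(activation_params.activation_type);

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
```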