@@ -226,9 +226,15 @@ C10_LAUNCH_BOUNDS_1(num_threads())
226226__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
227227 using traits = function_traits<func_t >;
228228 constexpr auto io_size = calc_io_size<func_t >();
229- int remaining = N - io_block_work_size<io_size>() * blockIdx .x ;
229+ #ifdef __gfx942__
230+ constexpr int tws = (io_size >= 2 ) ? 8 : 16 ;
231+ #else
232+ constexpr int tws = elems_per_thread<io_size>();
233+ #endif
234+ constexpr int bws = tws * num_threads ();
235+ int remaining = N - bws * blockIdx .x ;
230236
231- if (remaining < io_block_work_size<io_size>() ) { // if this block handles the reminder,
237+ if (remaining < bws ) { // if this block handles the reminder,
232238 // just do a naive unrolled loop
233239 auto input_calc = TrivialOffsetCalculator<traits::arity>();
234240 auto output_calc = TrivialOffsetCalculator<1 >();
@@ -240,14 +246,14 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
240246 decltype (output_calc),
241247 memory::LoadWithoutCast,
242248 memory::StoreWithoutCast,
243- elems_per_thread<io_size>() >(
249+ tws >(
244250 data, remaining, input_calc, output_calc, loader, storer);
245251 elementwise_kernel_helper (f, policy);
246252 } else { // if this block has a full `block_work_size` data to handle, use
247253 // vectorized memory access
248254 constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
249255 elementwise_kernel_helper (
250- f, memory::policies::vectorized<optimal_vec_size, array_t , elems_per_thread<io_size>() >(data));
256+ f, memory::policies::vectorized<optimal_vec_size, array_t , tws >(data));
251257 }
252258}
253259#endif // USE_ROCM
@@ -285,10 +291,12 @@ static inline void launch_vectorized_kernel(
285291 TORCH_INTERNAL_ASSERT (N > 0 && N <= std::numeric_limits<int32_t >::max ());
286292 using traits = function_traits<func_t >;
287293 constexpr auto io_size = calc_io_size<func_t >();
288- int64_t grid = (N + io_block_work_size<io_size>() - 1 ) / io_block_work_size<io_size>();
289294 auto stream = at::cuda::getCurrentCUDAStream ();
290295#ifdef USE_ROCM
291296 int vec_size = memory::can_vectorize_up_to<func_t >(data);
297+ c10::DeviceIndex curDevice = -1 ;
298+ AT_CUDA_CHECK (c10::cuda::GetDevice (&curDevice));
299+ int tws = at::detail::getCUDAHooks ().isGPUArch ({" gfx942" }, curDevice) ? ((io_size >= 2 ) ? 8 : 16 ) : elems_per_thread<io_size>();
292300#else
293301 using cpp_type = typename function_traits<func_t >::result_type;
294302 const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t >(data);
@@ -305,7 +313,10 @@ static inline void launch_vectorized_kernel(
305313 if constexpr (sizeof (cpp_type) < 2 ) {
306314 vec_size = std::min<uint16_t >(vec_size, 4 );
307315 }
316+ int tws = elems_per_thread<io_size>();
308317#endif
318+ int bws = tws * num_threads ();
319+ int64_t grid = (N + bws - 1 ) / bws;
309320 switch (vec_size) {
310321#ifdef USE_ROCM
311322 case 16 :
@@ -334,8 +345,9 @@ static inline void launch_vectorized_kernel(
334345 auto output_calc = TrivialOffsetCalculator<1 >();
335346 auto loader = memory::LoadWithoutCast ();
336347 auto storer = memory::StoreWithoutCast ();
348+ int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1 ) / io_block_work_size<io_size>();
337349 unrolled_elementwise_kernel<func_t , array_t , elems_per_thread<io_size>()>
338- <<<grid , num_threads(), 0 , stream>>> (
350+ <<<grid_unrolled , num_threads(), 0 , stream>>> (
339351 N, f, data, input_calc, output_calc, loader, storer);
340352 C10_CUDA_KERNEL_LAUNCH_CHECK ();
341353 break ;
0 commit comments