@@ -366,70 +366,68 @@ void max_pool2d_with_indices_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
     "max_pool2d_with_indices_out_cuda_frame",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool2d_with_indices_out_cuda_frame", [&] {
-        using accscalar_t = acc_type<scalar_t, true>;
-
-        scalar_t *output_data = output.data_ptr<scalar_t>();
-        scalar_t *input_data = input.data_ptr<scalar_t>();
-        int64_t *indices_data = indices.data_ptr<int64_t>();
-
-        switch (memory_format) {
-          case MemoryFormat::ChannelsLast: {
-            const int max_threads = std::min<int>(
-                at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
-            int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
-            int block_x = std::min<int>(
-                maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
-            int block_y = std::min<int>(
-                maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
-            int block_z = std::min<int>(
-                maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
-            block_x = std::min<int>(
-                maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
-            const dim3 block(block_x, block_y, block_z);
-
-            int kernel_stride_C = cuda::ATenCeilDiv(
-                safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
-            int kernel_size_C = cuda::ATenCeilDiv(
-                safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
-
-            int grid_x = nbatch*kernel_stride_C;
-            int grid_y = std::min<int>(
-                at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
-                cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
-            int grid_z = std::min<int>(
-                at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
-                cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
-            const dim3 grid(grid_x, grid_y, grid_z);
-
-            size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
-            AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
-
-            max_pool_forward_nhwc<scalar_t, scalar_t>
-            <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
-                input_data, nbatch,
-                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
-                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-                in_stride_n, in_stride_c,
-                in_stride_h, in_stride_w,
-                kernel_stride_C, kernel_size_C,
-                output_data, indices_data);
-            break;
-          }
-          case MemoryFormat::Contiguous: {
-            const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
-                                             BLOCK_THREADS);
-            max_pool_forward_nchw<scalar_t, scalar_t>
-              <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                count, input_data,
-                nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
-                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-                output_data, indices_data);
-            break;
-          }
-          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      using accscalar_t = acc_type<scalar_t, true>;
+
+      scalar_t *output_data = output.data_ptr<scalar_t>();
+      scalar_t *input_data = input.data_ptr<scalar_t>();
+      int64_t *indices_data = indices.data_ptr<int64_t>();
+
+      switch (memory_format) {
+        case MemoryFormat::ChannelsLast: {
+          const int max_threads = std::min<int>(
+              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
+          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
+          int block_x = std::min<int>(
+              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
+          int block_y = std::min<int>(
+              maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
+          int block_z = std::min<int>(
+              maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
+          block_x = std::min<int>(
+              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
+          const dim3 block(block_x, block_y, block_z);
+
+          int kernel_stride_C = cuda::ATenCeilDiv(
+              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
+          int kernel_size_C = cuda::ATenCeilDiv(
+              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
+
+          int grid_x = nbatch*kernel_stride_C;
+          int grid_y = std::min<int>(
+              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
+              cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
+          int grid_z = std::min<int>(
+              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
+              cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
+          const dim3 grid(grid_x, grid_y, grid_z);
+
+          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
+          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
+
+          max_pool_forward_nhwc<scalar_t, scalar_t>
+          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
+              input_data, nbatch,
+              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+              in_stride_n, in_stride_c,
+              in_stride_h, in_stride_w,
+              kernel_stride_C, kernel_size_C,
+              output_data, indices_data);
+          break;
         }
-      });
+        case MemoryFormat::Contiguous: {
+          const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
+                                           BLOCK_THREADS);
+          max_pool_forward_nchw<scalar_t, scalar_t>
+            <<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              count, input_data,
+              nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+              output_data, indices_data);
+          break;
+        }
+        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      }
     }
   );
 
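Note on the ChannelsLast branch kept by this hunk: the block shape is built by rounding each extent down to a power of two (channels capped at one warp on x, then x re-grown into whatever thread budget y and z leave free), and the channel axis is tiled so each block covers kernel_size_C channel slices while grid.x is striped nbatch * kernel_stride_C wide. Below is a minimal standalone sketch of that sizing arithmetic only, with the device queries replaced by assumed values (1024 threads per block, warp size 32, BLOCK_STRIDE = 2) and the ATen helpers lastPow2 / cuda::ATenCeilDiv re-implemented by behavior; it is an illustration, not the ATen code.

// Standalone sketch of the ChannelsLast launch sizing above.
// Assumptions: 1024 max threads/block, warp size 32, BLOCK_STRIDE = 2.
#include <algorithm>
#include <cstdio>

// Largest power of two <= n (n >= 1), mirroring ATen's lastPow2 in behavior.
static int lastPow2(int n) {
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

// Ceiling division, mirroring cuda::ATenCeilDiv in behavior.
static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Example shape; in ATen these come from the pooling arguments.
  const int nInputPlane = 96, outputHeight = 56, outputWidth = 56, nbatch = 8;
  const int max_threads = 1024, warp_size = 32, BLOCK_STRIDE = 2;

  // Channels on x (at most a warp), spatial extents on y/z, then x is
  // re-grown into the thread budget y and z left unused.
  int block_x = std::min(lastPow2(nInputPlane), warp_size);
  int block_y = std::min(lastPow2(outputWidth), max_threads / block_x);
  int block_z = std::min(lastPow2(outputHeight), max_threads / block_x / block_y);
  block_x = std::min(lastPow2(nInputPlane), max_threads / block_y / block_z);

  // Channel tiling: each block covers kernel_size_C slices of channels,
  // and grid.x is striped over batches times kernel_stride_C.
  int kernel_stride_C = ceil_div(nInputPlane, block_x * 4);
  int kernel_size_C = ceil_div(nInputPlane, block_x * kernel_stride_C);

  int grid_x = nbatch * kernel_stride_C;
  int grid_y = ceil_div(outputWidth, block_y * BLOCK_STRIDE);
  int grid_z = ceil_div(outputHeight, block_z * BLOCK_STRIDE);

  printf("block = (%d, %d, %d), grid = (%d, %d, %d)\n",
         block_x, block_y, block_z, grid_x, grid_y, grid_z);
  return 0;
}

For this example shape the sketch prints block = (32, 32, 1) and grid = (8, 1, 28); the real launch additionally clamps each dimension to maxThreadsDim / maxGridSize and checks shmem_size against sharedMemPerBlock, as the hunk shows.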
@@ -532,88 +530,86 @@ void max_pool2d_with_indices_backward_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
     "max_pool2d_with_indices_out_cuda_frame",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool2d_with_indices_out_cuda_frame", [&] {
-        using accscalar_t = acc_type<scalar_t, true>;
-
-        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
-        int64_t *indices_data = indices.data_ptr<int64_t>();
-
-        switch (memory_format) {
-          case MemoryFormat::ChannelsLast: {
-            const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
-            int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
-            int block_x = std::min<int>(
-                maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
-            int block_y = std::min<int>(
-                maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
-            int block_z = std::min<int>(
-                maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
-            block_x = std::min<int>(
-                maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
-            const dim3 block(block_x, block_y, block_z);
-
-            int kernel_stride_C = cuda::ATenCeilDiv(
-                safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
-            int kernel_size_C = cuda::ATenCeilDiv(
-                safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
-
-            int grid_x = nbatch*kernel_stride_C;
-            int grid_y = std::min<int>(
-                at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
-                cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
-            int grid_z = std::min<int>(
-                at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
-                cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
-            const dim3 grid(grid_x, grid_y, grid_z);
-
-            size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
-            AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
-
-            // The backward kernel is launched on the input instead of the output.
-            // If it were launched on the output, atomic_add would not provide much benefit on FP16.
-            // Please check the comments at https://github.com/pytorch/pytorch/pull/34519.
-            max_pool_backward_nhwc<scalar_t, accscalar_t>
-            <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
-                count,
-                gradOutput_data,
-                indices_data,
-                nbatch,
-                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
-                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-                out_stride_c, out_stride_h, out_stride_w,
-                in_stride_n, in_stride_c,
-                in_stride_h, in_stride_w,
-                kernel_stride_C, kernel_size_C,
-                gradInput_data);
-            break;
-          }
-          case MemoryFormat::Contiguous: {
-            int imgcount = inputWidth * inputHeight;
-            dim3 grid;
-            const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
-            grid.x = blocks;
-            grid.y = nbatch;
-            uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
-            if (maxGridY < grid.y) grid.y = maxGridY;
-            grid.z = nInputPlane;
-            uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
-            if (maxGridZ < grid.z) grid.z = maxGridZ;
-
-            max_pool_backward_nchw<scalar_t, accscalar_t>
-            <<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
-                count,
-                gradOutput_data,
-                indices_data,
-                nbatch,
-                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
-                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-                gradInput_data);
-            break;
-          }
-          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      using accscalar_t = acc_type<scalar_t, true>;
+
+      scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+      scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+      int64_t *indices_data = indices.data_ptr<int64_t>();
+
+      switch (memory_format) {
+        case MemoryFormat::ChannelsLast: {
+          const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
+          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
+          int block_x = std::min<int>(
+              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
+          int block_y = std::min<int>(
+              maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
+          int block_z = std::min<int>(
+              maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
+          block_x = std::min<int>(
+              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
+          const dim3 block(block_x, block_y, block_z);
+
+          int kernel_stride_C = cuda::ATenCeilDiv(
+              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
+          int kernel_size_C = cuda::ATenCeilDiv(
+              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
+
+          int grid_x = nbatch*kernel_stride_C;
+          int grid_y = std::min<int>(
+              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
+              cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
+          int grid_z = std::min<int>(
+              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
+              cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
+          const dim3 grid(grid_x, grid_y, grid_z);
+
+          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
+          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
+
+          // The backward kernel is launched on the input instead of the output.
+          // If it were launched on the output, atomic_add would not provide much benefit on FP16.
+          // Please check the comments at https://github.com/pytorch/pytorch/pull/34519.
+          max_pool_backward_nhwc<scalar_t, accscalar_t>
+          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
+              count,
+              gradOutput_data,
+              indices_data,
+              nbatch,
+              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+              out_stride_c, out_stride_h, out_stride_w,
+              in_stride_n, in_stride_c,
+              in_stride_h, in_stride_w,
+              kernel_stride_C, kernel_size_C,
+              gradInput_data);
+          break;
         }
-      });
+        case MemoryFormat::Contiguous: {
+          int imgcount = inputWidth * inputHeight;
+          dim3 grid;
+          const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
+          grid.x = blocks;
+          grid.y = nbatch;
+          uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+          if (maxGridY < grid.y) grid.y = maxGridY;
+          grid.z = nInputPlane;
+          uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
+          if (maxGridZ < grid.z) grid.z = maxGridZ;
+
+          max_pool_backward_nchw<scalar_t, accscalar_t>
+          <<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+              count,
+              gradOutput_data,
+              indices_data,
+              nbatch,
+              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+              gradInput_data);
+          break;
+        }
+        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+      }
     }
   );
 
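The Contiguous backward branch above clamps grid.y (batch) and grid.z (plane) to the device's maxGridSize limits, so one block per (batch, plane) pair cannot be assumed and the launched kernel has to stride over whatever the clamp cut off. Below is a sketch of that clamp-then-stride pattern; the body of max_pool_backward_nchw is not reproduced, and BLOCK_THREADS = 256 is an assumed placeholder for the constant used above.

// Sketch of launching with clamped grid.y/grid.z plus grid-stride loops.
#include <cuda_runtime.h>
#include <algorithm>

__global__ void clamped_grid_kernel(int nbatch, int nplane) {
  // Stride loops recover the batches/planes beyond gridDim.y / gridDim.z.
  for (int b = blockIdx.y; b < nbatch; b += gridDim.y) {
    for (int p = blockIdx.z; p < nplane; p += gridDim.z) {
      // ... per-(batch, plane) work over the image pixels would go here ...
    }
  }
}

void launch(int imgcount, int nbatch, int nplane) {
  const int BLOCK_THREADS = 256;  // assumed placeholder
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  dim3 grid((imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS,  // pixel blocks
            std::min(nbatch, prop.maxGridSize[1]),           // clamped batch
            std::min(nplane, prop.maxGridSize[2]));          // clamped plane
  clamped_grid_kernel<<<grid, BLOCK_THREADS>>>(nbatch, nplane);
}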