Skip to content

Commit 67a19fe

Browse files
zasdfgbnm
authored and facebook-github-bot committed
CUDA BFloat16 pooling (#45151)
Summary: Pull Request resolved: #45151 Reviewed By: ailzhang Differential Revision: D23854056 Pulled By: ngimel fbshipit-source-id: 32f0835218c2602a09654a9ac2d161c4eb360f90
1 parent 666223d commit 67a19fe

File tree

3 files changed

+172
-179
lines changed

3 files changed

+172
-179
lines changed

aten/src/ATen/native/cuda/DilatedMaxPool2d.cu

Lines changed: 140 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -366,70 +366,68 @@ void max_pool2d_with_indices_out_cuda_template(
366366
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
367367
"max_pool2d_with_indices_out_cuda_frame",
368368
[&] {
369-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool2d_with_indices_out_cuda_frame", [&] {
370-
using accscalar_t = acc_type<scalar_t, true>;
371-
372-
scalar_t *output_data = output.data_ptr<scalar_t>();
373-
scalar_t *input_data = input.data_ptr<scalar_t>();
374-
int64_t *indices_data = indices.data_ptr<int64_t>();
375-
376-
switch (memory_format) {
377-
case MemoryFormat::ChannelsLast: {
378-
const int max_threads = std::min<int>(
379-
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
380-
int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
381-
int block_x = std::min<int>(
382-
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
383-
int block_y = std::min<int>(
384-
maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
385-
int block_z = std::min<int>(
386-
maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
387-
block_x = std::min<int>(
388-
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
389-
const dim3 block(block_x, block_y, block_z);
390-
391-
int kernel_stride_C = cuda::ATenCeilDiv(
392-
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
393-
int kernel_size_C = cuda::ATenCeilDiv(
394-
safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
395-
396-
int grid_x = nbatch*kernel_stride_C;
397-
int grid_y = std::min<int>(
398-
at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
399-
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
400-
int grid_z = std::min<int>(
401-
at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
402-
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
403-
const dim3 grid(grid_x, grid_y, grid_z);
404-
405-
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
406-
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
407-
408-
max_pool_forward_nhwc<scalar_t, scalar_t>
409-
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
410-
input_data, nbatch,
411-
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
412-
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
413-
in_stride_n, in_stride_c,
414-
in_stride_h, in_stride_w,
415-
kernel_stride_C, kernel_size_C,
416-
output_data, indices_data);
417-
break;
418-
}
419-
case MemoryFormat::Contiguous: {
420-
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
421-
BLOCK_THREADS);
422-
max_pool_forward_nchw<scalar_t, scalar_t>
423-
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
424-
count, input_data,
425-
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
426-
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
427-
output_data, indices_data);
428-
break;
429-
}
430-
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
369+
using accscalar_t = acc_type<scalar_t, true>;
370+
371+
scalar_t *output_data = output.data_ptr<scalar_t>();
372+
scalar_t *input_data = input.data_ptr<scalar_t>();
373+
int64_t *indices_data = indices.data_ptr<int64_t>();
374+
375+
switch (memory_format) {
376+
case MemoryFormat::ChannelsLast: {
377+
const int max_threads = std::min<int>(
378+
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
379+
int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
380+
int block_x = std::min<int>(
381+
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
382+
int block_y = std::min<int>(
383+
maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
384+
int block_z = std::min<int>(
385+
maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
386+
block_x = std::min<int>(
387+
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
388+
const dim3 block(block_x, block_y, block_z);
389+
390+
int kernel_stride_C = cuda::ATenCeilDiv(
391+
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
392+
int kernel_size_C = cuda::ATenCeilDiv(
393+
safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
394+
395+
int grid_x = nbatch*kernel_stride_C;
396+
int grid_y = std::min<int>(
397+
at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
398+
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
399+
int grid_z = std::min<int>(
400+
at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
401+
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
402+
const dim3 grid(grid_x, grid_y, grid_z);
403+
404+
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
405+
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
406+
407+
max_pool_forward_nhwc<scalar_t, scalar_t>
408+
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
409+
input_data, nbatch,
410+
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
411+
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
412+
in_stride_n, in_stride_c,
413+
in_stride_h, in_stride_w,
414+
kernel_stride_C, kernel_size_C,
415+
output_data, indices_data);
416+
break;
431417
}
432-
});
418+
case MemoryFormat::Contiguous: {
419+
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
420+
BLOCK_THREADS);
421+
max_pool_forward_nchw<scalar_t, scalar_t>
422+
<<<cuda::ATenCeilDiv(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
423+
count, input_data,
424+
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
425+
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
426+
output_data, indices_data);
427+
break;
428+
}
429+
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
430+
}
433431
}
434432
);
435433

@@ -532,88 +530,86 @@ void max_pool2d_with_indices_backward_out_cuda_template(
532530
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
533531
"max_pool2d_with_indices_out_cuda_frame",
534532
[&] {
535-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool2d_with_indices_out_cuda_frame", [&] {
536-
using accscalar_t = acc_type<scalar_t, true>;
537-
538-
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
539-
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
540-
int64_t *indices_data = indices.data_ptr<int64_t>();
541-
542-
switch (memory_format) {
543-
case MemoryFormat::ChannelsLast: {
544-
const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
545-
int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
546-
int block_x = std::min<int>(
547-
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
548-
int block_y = std::min<int>(
549-
maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
550-
int block_z = std::min<int>(
551-
maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
552-
block_x = std::min<int>(
553-
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
554-
const dim3 block(block_x, block_y, block_z);
555-
556-
int kernel_stride_C = cuda::ATenCeilDiv(
557-
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
558-
int kernel_size_C = cuda::ATenCeilDiv(
559-
safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
560-
561-
int grid_x = nbatch*kernel_stride_C;
562-
int grid_y = std::min<int>(
563-
at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
564-
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
565-
int grid_z = std::min<int>(
566-
at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
567-
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
568-
const dim3 grid(grid_x, grid_y, grid_z);
569-
570-
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
571-
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
572-
573-
// The backward kernel is launched on input instead output.
574-
// If it is launched on output layer, atomic_add would not provide much benefit on FP16.
575-
// Please check comments at https://github.com/pytorch/pytorch/pull/34519.
576-
max_pool_backward_nhwc<scalar_t, accscalar_t>
577-
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
578-
count,
579-
gradOutput_data,
580-
indices_data,
581-
nbatch,
582-
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
583-
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
584-
out_stride_c, out_stride_h, out_stride_w,
585-
in_stride_n, in_stride_c,
586-
in_stride_h, in_stride_w,
587-
kernel_stride_C, kernel_size_C,
588-
gradInput_data);
589-
break;
590-
}
591-
case MemoryFormat::Contiguous: {
592-
int imgcount = inputWidth * inputHeight;
593-
dim3 grid;
594-
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
595-
grid.x = blocks;
596-
grid.y = nbatch;
597-
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
598-
if (maxGridY < grid.y) grid.y = maxGridY;
599-
grid.z = nInputPlane;
600-
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
601-
if (maxGridZ < grid.z) grid.z = maxGridZ;
602-
603-
max_pool_backward_nchw<scalar_t, accscalar_t>
604-
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
605-
count,
606-
gradOutput_data,
607-
indices_data,
608-
nbatch,
609-
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
610-
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
611-
gradInput_data);
612-
break;
613-
}
614-
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
533+
using accscalar_t = acc_type<scalar_t, true>;
534+
535+
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
536+
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
537+
int64_t *indices_data = indices.data_ptr<int64_t>();
538+
539+
switch (memory_format) {
540+
case MemoryFormat::ChannelsLast: {
541+
const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
542+
int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
543+
int block_x = std::min<int>(
544+
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
545+
int block_y = std::min<int>(
546+
maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
547+
int block_z = std::min<int>(
548+
maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
549+
block_x = std::min<int>(
550+
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
551+
const dim3 block(block_x, block_y, block_z);
552+
553+
int kernel_stride_C = cuda::ATenCeilDiv(
554+
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
555+
int kernel_size_C = cuda::ATenCeilDiv(
556+
safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);
557+
558+
int grid_x = nbatch*kernel_stride_C;
559+
int grid_y = std::min<int>(
560+
at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
561+
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
562+
int grid_z = std::min<int>(
563+
at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
564+
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
565+
const dim3 grid(grid_x, grid_y, grid_z);
566+
567+
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
568+
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
569+
570+
// The backward kernel is launched on input instead output.
571+
// If it is launched on output layer, atomic_add would not provide much benefit on FP16.
572+
// Please check comments at https://github.com/pytorch/pytorch/pull/34519.
573+
max_pool_backward_nhwc<scalar_t, accscalar_t>
574+
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
575+
count,
576+
gradOutput_data,
577+
indices_data,
578+
nbatch,
579+
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
580+
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
581+
out_stride_c, out_stride_h, out_stride_w,
582+
in_stride_n, in_stride_c,
583+
in_stride_h, in_stride_w,
584+
kernel_stride_C, kernel_size_C,
585+
gradInput_data);
586+
break;
615587
}
616-
});
588+
case MemoryFormat::Contiguous: {
589+
int imgcount = inputWidth * inputHeight;
590+
dim3 grid;
591+
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
592+
grid.x = blocks;
593+
grid.y = nbatch;
594+
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
595+
if (maxGridY < grid.y) grid.y = maxGridY;
596+
grid.z = nInputPlane;
597+
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
598+
if (maxGridZ < grid.z) grid.z = maxGridZ;
599+
600+
max_pool_backward_nchw<scalar_t, accscalar_t>
601+
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
602+
count,
603+
gradOutput_data,
604+
indices_data,
605+
nbatch,
606+
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
607+
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
608+
gradInput_data);
609+
break;
610+
}
611+
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
612+
}
617613
}
618614
);
619615

aten/src/ATen/native/cuda/DilatedMaxPool3d.cu

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,18 @@ void max_pool3d_with_indices_out_cuda_template(
276276
input.scalar_type(),
277277
"max_pool3d_with_indices_out_frame",
278278
[&]{
279-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool3d_with_indices_out_frame", [&] {
280-
scalar_t *input_data = work_input.data_ptr<scalar_t>();
281-
int64_t totalZ = otime * nslices * nbatch;
282-
283-
max_pool3d_with_indices_out_frame(
284-
input_data, work_output, work_indices,
285-
totalZ,
286-
itime, iheight, iwidth,
287-
otime, oheight, owidth,
288-
kT, kH, kW,
289-
dT, dH, dW,
290-
pT, pH, pW,
291-
dilationT, dilationH, dilationW);
292-
});
279+
scalar_t *input_data = work_input.data_ptr<scalar_t>();
280+
int64_t totalZ = otime * nslices * nbatch;
281+
282+
max_pool3d_with_indices_out_frame(
283+
input_data, work_output, work_indices,
284+
totalZ,
285+
itime, iheight, iwidth,
286+
otime, oheight, owidth,
287+
kT, kH, kW,
288+
dT, dH, dW,
289+
pT, pH, pW,
290+
dilationT, dilationH, dilationW);
293291
}
294292
);
295293
}
@@ -387,19 +385,17 @@ void max_pool3d_with_indices_backward_out_cuda_template(
387385
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
388386
"max_pool3d_with_indices_backward_out_frame",
389387
[&] {
390-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "max_pool3d_with_indices_backward_out_frame", [&] {
391-
const int64_t totalZ = otime * nslices * nbatch;
392-
scalar_t *grad_input_data = work_grad_input.data_ptr<scalar_t>();
393-
394-
max_pool3d_with_indices_backward_out_frame(
395-
grad_input_data, work_grad_output, work_indices,
396-
totalZ,
397-
itime, iheight, iwidth,
398-
oheight, owidth,
399-
dT, dH, dW,
400-
pT, pH, pW,
401-
dilationT, dilationH, dilationW);
402-
});
388+
const int64_t totalZ = otime * nslices * nbatch;
389+
scalar_t *grad_input_data = work_grad_input.data_ptr<scalar_t>();
390+
391+
max_pool3d_with_indices_backward_out_frame(
392+
grad_input_data, work_grad_output, work_indices,
393+
totalZ,
394+
itime, iheight, iwidth,
395+
oheight, owidth,
396+
dT, dH, dW,
397+
pT, pH, pW,
398+
dilationT, dilationH, dilationW);
403399
}
404400
);
405401
}

0 commit comments

Comments
 (0)