@@ -262,14 +262,6 @@ void avg_pool2d_out_cuda_template(
262262 const int padH = safe_downcast<int , int64_t >(padding[0 ]);
263263 const int padW = padding.size () == 1 ? padH : safe_downcast<int , int64_t >(padding[1 ]);
264264
265- const auto memory_format = input_.suggest_memory_format ();
266- if (memory_format == at::MemoryFormat::ChannelsLast){
267- TORCH_CHECK (input_.ndimension () == 4 ,
268- " non-empty 4D (batch mode) tensor expected for input with channels_last layout" );
269- } else {
270- TORCH_CHECK ((input_.ndimension () == 3 || input_.ndimension () == 4 ),
271- " non-empty 3D or 4D (batch mode) tensor expected for input" );
272- }
273265
274266 TORCH_CHECK (!divisor_override.has_value () || divisor_override.value () != 0 ,
275267 " divisor must be not zero" );
@@ -281,13 +273,14 @@ void avg_pool2d_out_cuda_template(
281273
282274 const int64_t outputWidth = pooling_output_shape<int64_t >(inputWidth, kW , padW, dW, 1 , ceil_mode);
283275 const int64_t outputHeight = pooling_output_shape<int64_t >(inputHeight, kH , padH, dH, 1 , ceil_mode);
276+ const auto memory_format = input_.suggest_memory_format ();
284277
285278 pool2d_shape_check (
286279 input_,
287280 kH , kW , dH, dW, padH, padW, 1 , 1 ,
288281 nInputPlane,
289282 inputHeight, inputWidth,
290- outputHeight, outputWidth);
283+ outputHeight, outputWidth, memory_format );
291284
292285 Tensor input = input_.contiguous (memory_format);
293286
@@ -300,18 +293,36 @@ void avg_pool2d_out_cuda_template(
300293 bool use_divisor = divisor_override.has_value ();
301294 const auto divisor_override_value = use_divisor ? divisor_override.value () : 0 ;
302295
303- AT_DISPATCH_FLOATING_TYPES_AND2 (kHalf , kBFloat16 , input.scalar_type (),
304- " avg_pool2d_out_cuda_frame" ,
305- [&] {
306- using accscalar_t = acc_type<scalar_t , true >;
307-
308- scalar_t *output_data = output.data_ptr <scalar_t >();
309- scalar_t *input_data = input.data_ptr <scalar_t >();
310-
311- switch (memory_format){
312- case MemoryFormat::ChannelsLast: {
313- output.unsafeGetTensorImpl ()->empty_tensor_restride (MemoryFormat::ChannelsLast);
314- avg_pool2d_out_cuda_frame_nhwc<scalar_t , accscalar_t >
296+ if (count != 0 ) {
297+ AT_DISPATCH_FLOATING_TYPES_AND2 (kHalf , kBFloat16 , input.scalar_type (),
298+ " avg_pool2d_out_cuda_frame" ,
299+ [&] {
300+ using accscalar_t = acc_type<scalar_t , true >;
301+
302+ scalar_t *output_data = output.data_ptr <scalar_t >();
303+ scalar_t *input_data = input.data_ptr <scalar_t >();
304+
305+ switch (memory_format){
306+ case MemoryFormat::ChannelsLast: {
307+ output.unsafeGetTensorImpl ()->empty_tensor_restride (MemoryFormat::ChannelsLast);
308+ avg_pool2d_out_cuda_frame_nhwc<scalar_t , accscalar_t >
309+ <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
310+ count,
311+ input_data,
312+ nbatch,
313+ nInputPlane,
314+ inputHeight, inputWidth,
315+ outputHeight, outputWidth,
316+ kH , kW ,
317+ dH, dW,
318+ padH, padW,
319+ output_data,
320+ divisor_override_value,
321+ count_include_pad, use_divisor);
322+ break ;
323+ }
324+ case MemoryFormat::Contiguous: {
325+ avg_pool2d_out_cuda_frame<scalar_t , accscalar_t >
315326 <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
316327 count,
317328 input_data,
@@ -325,31 +336,13 @@ void avg_pool2d_out_cuda_template(
325336 output_data,
326337 divisor_override_value,
327338 count_include_pad, use_divisor);
328- TORCH_CUDA_KERNEL_LAUNCH_CHECK ();
329- break ;
330- }
331- case MemoryFormat::Contiguous: {
332- avg_pool2d_out_cuda_frame<scalar_t , accscalar_t >
333- <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
334- count,
335- input_data,
336- nbatch,
337- nInputPlane,
338- inputHeight, inputWidth,
339- outputHeight, outputWidth,
340- kH , kW ,
341- dH, dW,
342- padH, padW,
343- output_data,
344- divisor_override_value,
345- count_include_pad, use_divisor);
346- TORCH_CUDA_KERNEL_LAUNCH_CHECK ();
347- break ;
339+ break ;
340+ }
341+ default : TORCH_CHECK (false , " Unsupported memory format. Supports only ChannelsLast, Contiguous" );
348342 }
349- default : TORCH_CHECK (false , " Unsupported memory format. Supports only ChannelsLast, Contiguous" );
350343 }
351- }
352- );
344+ );
345+ }
353346
354347 if (input.ndimension () == 3 ) {
355348 output.resize_ ({nInputPlane, outputHeight, outputWidth});
@@ -395,14 +388,6 @@ Tensor& avg_pool2d_backward_out_cuda_template(
395388 " divisor must be not zero" );
396389
397390 const auto memory_format = input_.suggest_memory_format ();
398- if (memory_format == at::MemoryFormat::ChannelsLast) {
399- TORCH_CHECK (input_.ndimension () == 4 ,
400- " non-empty 4D (batch mode) tensor expected for input with channels_last layout" );
401- } else {
402- TORCH_CHECK ((input_.ndimension () == 3 || input_.ndimension () == 4 ),
403- " non-empty 3D or 4D (batch mode) tensor expected for input" );
404- }
405-
406391 const Tensor input = input_.contiguous (memory_format);
407392 const Tensor gradOutput = gradOutput_.contiguous (memory_format);
408393
@@ -421,11 +406,14 @@ Tensor& avg_pool2d_backward_out_cuda_template(
421406 kH , kW , dH, dW, padH, padW,
422407 nInputPlane,
423408 inputHeight, inputWidth,
424- outputHeight, outputWidth);
409+ outputHeight, outputWidth, memory_format );
425410
426411 gradInput.resize_as_ (input);
427-
428412 const int32_t count = safe_downcast<int32_t , int64_t >(input.numel ());
413+ if (count == 0 ) {
414+ return gradInput;
415+ }
416+
429417 const uint32_t num_threads = std::min (at::cuda::getCurrentDeviceProperties ()->maxThreadsPerBlock , 1024 );
430418 const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t >(count, num_threads);
431419
0 commit comments