@@ -262,6 +262,14 @@ void avg_pool2d_out_cuda_template(
262262 const int padH = safe_downcast<int , int64_t >(padding[0 ]);
263263 const int padW = padding.size () == 1 ? padH : safe_downcast<int , int64_t >(padding[1 ]);
264264
265+ const auto memory_format = input_.suggest_memory_format ();
266+ if (memory_format == at::MemoryFormat::ChannelsLast){
267+ TORCH_CHECK (input_.ndimension () == 4 ,
268+ " non-empty 4D (batch mode) tensor expected for input with channels_last layout" );
269+ } else {
270+ TORCH_CHECK ((input_.ndimension () == 3 || input_.ndimension () == 4 ),
271+ " non-empty 3D or 4D (batch mode) tensor expected for input" );
272+ }
265273
266274 TORCH_CHECK (!divisor_override.has_value () || divisor_override.value () != 0 ,
267275 " divisor must be not zero" );
@@ -273,14 +281,13 @@ void avg_pool2d_out_cuda_template(
273281
274282 const int64_t outputWidth = pooling_output_shape<int64_t >(inputWidth, kW , padW, dW, 1 , ceil_mode);
275283 const int64_t outputHeight = pooling_output_shape<int64_t >(inputHeight, kH , padH, dH, 1 , ceil_mode);
276- const auto memory_format = input_.suggest_memory_format ();
277284
278285 pool2d_shape_check (
279286 input_,
280287 kH , kW , dH, dW, padH, padW, 1 , 1 ,
281288 nInputPlane,
282289 inputHeight, inputWidth,
283- outputHeight, outputWidth, memory_format );
290+ outputHeight, outputWidth);
284291
285292 Tensor input = input_.contiguous (memory_format);
286293
@@ -293,36 +300,18 @@ void avg_pool2d_out_cuda_template(
293300 bool use_divisor = divisor_override.has_value ();
294301 const auto divisor_override_value = use_divisor ? divisor_override.value () : 0 ;
295302
296- if (count != 0 ) {
297- AT_DISPATCH_FLOATING_TYPES_AND2 (kHalf , kBFloat16 , input.scalar_type (),
298- " avg_pool2d_out_cuda_frame" ,
299- [&] {
300- using accscalar_t = acc_type<scalar_t , true >;
301-
302- scalar_t *output_data = output.data_ptr <scalar_t >();
303- scalar_t *input_data = input.data_ptr <scalar_t >();
304-
305- switch (memory_format){
306- case MemoryFormat::ChannelsLast: {
307- output.unsafeGetTensorImpl ()->empty_tensor_restride (MemoryFormat::ChannelsLast);
308- avg_pool2d_out_cuda_frame_nhwc<scalar_t , accscalar_t >
309- <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
310- count,
311- input_data,
312- nbatch,
313- nInputPlane,
314- inputHeight, inputWidth,
315- outputHeight, outputWidth,
316- kH , kW ,
317- dH, dW,
318- padH, padW,
319- output_data,
320- divisor_override_value,
321- count_include_pad, use_divisor);
322- break ;
323- }
324- case MemoryFormat::Contiguous: {
325- avg_pool2d_out_cuda_frame<scalar_t , accscalar_t >
303+ AT_DISPATCH_FLOATING_TYPES_AND2 (kHalf , kBFloat16 , input.scalar_type (),
304+ " avg_pool2d_out_cuda_frame" ,
305+ [&] {
306+ using accscalar_t = acc_type<scalar_t , true >;
307+
308+ scalar_t *output_data = output.data_ptr <scalar_t >();
309+ scalar_t *input_data = input.data_ptr <scalar_t >();
310+
311+ switch (memory_format){
312+ case MemoryFormat::ChannelsLast: {
313+ output.unsafeGetTensorImpl ()->empty_tensor_restride (MemoryFormat::ChannelsLast);
314+ avg_pool2d_out_cuda_frame_nhwc<scalar_t , accscalar_t >
326315 <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
327316 count,
328317 input_data,
@@ -336,13 +325,31 @@ void avg_pool2d_out_cuda_template(
336325 output_data,
337326 divisor_override_value,
338327 count_include_pad, use_divisor);
339- break ;
340- }
341- default : TORCH_CHECK (false , " Unsupported memory format. Supports only ChannelsLast, Contiguous" );
328+ TORCH_CUDA_KERNEL_LAUNCH_CHECK ();
329+ break ;
330+ }
331+ case MemoryFormat::Contiguous: {
332+ avg_pool2d_out_cuda_frame<scalar_t , accscalar_t >
333+ <<<num_blocks, num_threads, 0 , at::cuda::getCurrentCUDAStream()>>> (
334+ count,
335+ input_data,
336+ nbatch,
337+ nInputPlane,
338+ inputHeight, inputWidth,
339+ outputHeight, outputWidth,
340+ kH , kW ,
341+ dH, dW,
342+ padH, padW,
343+ output_data,
344+ divisor_override_value,
345+ count_include_pad, use_divisor);
346+ TORCH_CUDA_KERNEL_LAUNCH_CHECK ();
347+ break ;
342348 }
349+ default : TORCH_CHECK (false , " Unsupported memory format. Supports only ChannelsLast, Contiguous" );
343350 }
344- );
345- }
351+ }
352+ );
346353
347354 if (input.ndimension () == 3 ) {
348355 output.resize_ ({nInputPlane, outputHeight, outputWidth});
@@ -388,6 +395,14 @@ Tensor& avg_pool2d_backward_out_cuda_template(
388395 " divisor must be not zero" );
389396
390397 const auto memory_format = input_.suggest_memory_format ();
398+ if (memory_format == at::MemoryFormat::ChannelsLast) {
399+ TORCH_CHECK (input_.ndimension () == 4 ,
400+ " non-empty 4D (batch mode) tensor expected for input with channels_last layout" );
401+ } else {
402+ TORCH_CHECK ((input_.ndimension () == 3 || input_.ndimension () == 4 ),
403+ " non-empty 3D or 4D (batch mode) tensor expected for input" );
404+ }
405+
391406 const Tensor input = input_.contiguous (memory_format);
392407 const Tensor gradOutput = gradOutput_.contiguous (memory_format);
393408
@@ -406,14 +421,11 @@ Tensor& avg_pool2d_backward_out_cuda_template(
406421 kH , kW , dH, dW, padH, padW,
407422 nInputPlane,
408423 inputHeight, inputWidth,
409- outputHeight, outputWidth, memory_format );
424+ outputHeight, outputWidth);
410425
411426 gradInput.resize_as_ (input);
427+
412428 const int32_t count = safe_downcast<int32_t , int64_t >(input.numel ());
413- if (count == 0 ) {
414- return gradInput;
415- }
416-
417429 const uint32_t num_threads = std::min (at::cuda::getCurrentDeviceProperties ()->maxThreadsPerBlock , 1024 );
418430 const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t >(count, num_threads);
419431
0 commit comments