
Commit ceeab70

v0dro authored and facebook-github-bot committed
Reopen PR for 0 dim batch size for AvgPool2d. (#47426)
Summary: Resubmitting #40694 since it could not be landed for some reason. CC ngimel

Pull Request resolved: #47426
Reviewed By: mruberry
Differential Revision: D24941350
Pulled By: ngimel
fbshipit-source-id: b7e50346d86eb63aaaf4fdd5ee71fafee2d0b476
1 parent 260daf0 commit ceeab70
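
For context, a minimal usage sketch of what this change enables (illustrative only, not part of the diff): AvgPool2d now accepts an input whose batch dimension is zero and returns a correspondingly empty output, instead of failing the old "non-empty ... tensor expected" shape check.

# Sketch only: assumes a PyTorch build that includes this commit.
import torch

pool = torch.nn.AvgPool2d(3, stride=2)
x = torch.randn(0, 16, 20, 32)   # batch dimension of size 0
y = pool(x)                      # previously rejected by the shape check
print(y.shape)                   # torch.Size([0, 16, 9, 15])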

File tree: 7 files changed, +86 -77 lines changed


aten/src/ATen/native/AveragePool2d.cpp

Lines changed: 3 additions & 9 deletions
@@ -119,9 +119,6 @@ void avg_pool2d_out_cpu_template(
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 
-  TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
-    "non-empty 2D or 3D (batch mode) tensor expected for input");
-
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
 
@@ -139,7 +136,7 @@ void avg_pool2d_out_cpu_template(
     kH, kW, dH, dW, padH, padW, 1, 1,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth, input_.suggest_memory_format());
 
   if (input_.ndimension() == 3) {
     output.resize_({nInputPlane, outputHeight, outputWidth});
@@ -276,12 +273,8 @@ Tensor& avg_pool2d_backward_out_cpu_template(
     "avg_pool2d: padding must either be a single int, or a tuple of two ints");
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
-
   const int64_t ndim = input.ndimension();
 
-  TORCH_CHECK((ndim == 3 || ndim == 4),
-    "non-empty 3D or 4D (batch mode) tensor expected for input");
-
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero");
 
   /* sizes */
@@ -299,7 +292,8 @@ Tensor& avg_pool2d_backward_out_cpu_template(
     kH, kW, dH, dW, padH, padW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth,
+    input.suggest_memory_format());
 
   /* get contiguous gradOutput */
   const Tensor gradOutput = gradOutput_.contiguous();

aten/src/ATen/native/DilatedMaxPool2d.cpp

Lines changed: 3 additions & 2 deletions
@@ -169,7 +169,7 @@ void max_pool2d_with_indices_out_cpu_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth, input_.suggest_memory_format());
 
   /* get contiguous input */
   Tensor input = input_.contiguous();
@@ -360,7 +360,8 @@ Tensor& max_pool2d_with_indices_backward_out_cpu_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight_for_shape_check, outputWidth_for_shape_check);
+    outputHeight_for_shape_check, outputWidth_for_shape_check,
+    input.suggest_memory_format());
 
   /* backprop */
   if (input.ndimension() == 3)

aten/src/ATen/native/Pool.h

Lines changed: 20 additions & 8 deletions
@@ -54,9 +54,8 @@ pool2d_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth)
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
 {
-  const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
 
   TORCH_CHECK(kW > 0 && kH > 0,
@@ -69,8 +68,19 @@ pool2d_shape_check(
     "dilation should be greater than zero, but got ",
     "dilationH: ", dilationH, " dilationW: ", dilationW);
 
-  TORCH_CHECK(input.numel() > 0 && (ndim == 3 || ndim == 4),
-    "non-empty 3D or 4D input tensor expected but got ndim: ", ndim);
+  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
+  if (memory_format == at::MemoryFormat::ChannelsLast){
+    // Expect tensor in NHWC format and allow 0-dim only for N.
+    TORCH_CHECK((input.ndimension() == 4 && valid_dims && input.size(3) != 0),
+      "Expected 4D (batch mode) tensor expected for input with channels_last layout"
+      " with optional 0 dim batch size for input, but got: ", input.sizes());
+  } else {
+    TORCH_CHECK((input.ndimension() == 3 && input.size(0) != 0 && valid_dims) ||
+      (input.ndimension() == 4 && valid_dims && input.size(3) != 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
+      input.sizes());
+  }
+
   TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
     "pad should be smaller than half of kernel size, but got ",
     "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
@@ -93,13 +103,13 @@ max_pool2d_backward_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth,
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format,
   bool cuda=false)
 {
   pool2d_shape_check(
     input,
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
 
   const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
@@ -122,12 +132,14 @@ avg_pool2d_backward_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth)
+  int64_t outputHeight, int64_t outputWidth,
+  MemoryFormat memory_format)
 {
   pool2d_shape_check(
     input,
     kH, kW, dH, dW, padH, padW, 1, 1,
-    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+    memory_format);
 
   const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
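
In plain terms, the reworked check in pool2d_shape_check accepts a zero-sized batch dimension but still rejects zero-sized channel or spatial dimensions, and channels_last inputs must be 4D. A small Python sketch of shapes that pass and fail under the new TORCH_CHECK logic (illustrative only, assuming a build that includes this commit):

import torch

pool = torch.nn.AvgPool2d(3, stride=2)

pool(torch.randn(0, 16, 20, 32))   # OK: only the batch (N) dimension is zero
pool(torch.randn(16, 20, 32))      # OK: 3D (unbatched) input with all dims non-zero

try:
    pool(torch.randn(16, 0, 20, 32))   # rejected: zero-sized channel dimension
except RuntimeError as e:
    print(e)  # "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input ..."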

aten/src/ATen/native/cuda/AveragePool2d.cu

Lines changed: 42 additions & 54 deletions
@@ -262,14 +262,6 @@ void avg_pool2d_out_cuda_template(
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 
-  const auto memory_format = input_.suggest_memory_format();
-  if (memory_format == at::MemoryFormat::ChannelsLast){
-    TORCH_CHECK(input_.ndimension() == 4,
-      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
-  } else {
-    TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
-      "non-empty 3D or 4D (batch mode) tensor expected for input");
-  }
 
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
@@ -281,13 +273,14 @@ void avg_pool2d_out_cuda_template(
 
   const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
   const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
+  const auto memory_format = input_.suggest_memory_format();
 
   pool2d_shape_check(
     input_,
     kH, kW, dH, dW, padH, padW, 1, 1,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth, memory_format);
 
   Tensor input = input_.contiguous(memory_format);
 
@@ -300,18 +293,36 @@ void avg_pool2d_out_cuda_template(
   bool use_divisor = divisor_override.has_value();
   const auto divisor_override_value = use_divisor ? divisor_override.value() : 0;
 
-  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
-    "avg_pool2d_out_cuda_frame",
-    [&] {
-      using accscalar_t = acc_type<scalar_t, true>;
-
-      scalar_t *output_data = output.data_ptr<scalar_t>();
-      scalar_t *input_data = input.data_ptr<scalar_t>();
-
-      switch (memory_format){
-        case MemoryFormat::ChannelsLast: {
-          output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
-          avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
+  if (count != 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
+      "avg_pool2d_out_cuda_frame",
+      [&] {
+        using accscalar_t = acc_type<scalar_t, true>;
+
+        scalar_t *output_data = output.data_ptr<scalar_t>();
+        scalar_t *input_data = input.data_ptr<scalar_t>();
+
+        switch (memory_format){
+          case MemoryFormat::ChannelsLast: {
+            output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
+            avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
+              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                count,
+                input_data,
+                nbatch,
+                nInputPlane,
+                inputHeight, inputWidth,
+                outputHeight, outputWidth,
+                kH, kW,
+                dH, dW,
+                padH, padW,
+                output_data,
+                divisor_override_value,
+                count_include_pad, use_divisor);
+            break;
+          }
+          case MemoryFormat::Contiguous: {
+            avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                count,
                input_data,
@@ -325,31 +336,13 @@ void avg_pool2d_out_cuda_template(
                output_data,
                divisor_override_value,
                count_include_pad, use_divisor);
-          TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-          break;
-        }
-        case MemoryFormat::Contiguous: {
-          avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
-            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-              count,
-              input_data,
-              nbatch,
-              nInputPlane,
-              inputHeight, inputWidth,
-              outputHeight, outputWidth,
-              kH, kW,
-              dH, dW,
-              padH, padW,
-              output_data,
-              divisor_override_value,
-              count_include_pad, use_divisor);
-          TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-          break;
+            break;
+          }
+          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
         }
-        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
       }
-    }
-  );
+    );
+  }
 
   if (input.ndimension() == 3) {
     output.resize_({nInputPlane, outputHeight, outputWidth});
@@ -395,14 +388,6 @@ Tensor& avg_pool2d_backward_out_cuda_template(
     "divisor must be not zero");
 
   const auto memory_format = input_.suggest_memory_format();
-  if (memory_format == at::MemoryFormat::ChannelsLast) {
-    TORCH_CHECK(input_.ndimension() == 4,
-      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
-  } else {
-    TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
-      "non-empty 3D or 4D (batch mode) tensor expected for input");
-  }
-
   const Tensor input = input_.contiguous(memory_format);
   const Tensor gradOutput = gradOutput_.contiguous(memory_format);
 
@@ -421,11 +406,14 @@ Tensor& avg_pool2d_backward_out_cuda_template(
     kH, kW, dH, dW, padH, padW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth, memory_format);
 
   gradInput.resize_as_(input);
-
   const int32_t count = safe_downcast<int32_t, int64_t>(input.numel());
+  if (count == 0) {
+    return gradInput;
+  }
+
   const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
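
The forward path now wraps the kernel dispatch in "if (count != 0)" and the backward path returns the resized (empty) gradInput early when count == 0, since a zero-element input would otherwise request a zero-block grid, which is not a valid CUDA launch configuration. A hedged end-to-end sketch of the effect (assumes a CUDA device is available and a build with this commit):

# Sketch only: forward and backward through a zero-batch input on CUDA.
import torch

x = torch.randn(0, 16, 20, 32, device='cuda', requires_grad=True)
y = torch.nn.functional.avg_pool2d(x, 3, stride=2)   # empty output, no kernel launched
y.sum().backward()                                   # empty gradient, backward kernel skipped
print(y.shape, x.grad.shape)   # torch.Size([0, 16, 9, 15]) torch.Size([0, 16, 20, 32])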

aten/src/ATen/native/cuda/DilatedMaxPool2d.cu

Lines changed: 2 additions & 2 deletions
@@ -346,7 +346,7 @@ void max_pool2d_with_indices_out_cuda_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth);
+    outputHeight, outputWidth, memory_format);
 
   Tensor input = input_.contiguous(memory_format);
 
@@ -513,7 +513,7 @@ void max_pool2d_with_indices_backward_out_cuda_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth,
+    outputHeight, outputWidth, memory_format,
     /*cuda=*/ true);
 
   const Tensor gradOutput = gradOutput_.contiguous(memory_format);

aten/src/ATen/native/vulkan/VulkanAten.cpp

Lines changed: 3 additions & 2 deletions
@@ -164,7 +164,7 @@ Tensor avg_pool2d(
       pooling_output_shape<int64_t>(iW, kW, padW, dW, 1, ceil_mode);
 
   pool2d_shape_check(
-      self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW);
+      self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW, self.suggest_memory_format());
 
   VulkanTensor y{{iN, iC, oH, oW}};
   vulkan::detail::avg_pool2d(
@@ -234,7 +234,8 @@ Tensor max_pool2d(
       iH,
       iW,
       oH,
-      oW);
+      oW,
+      self.suggest_memory_format());
 
   VulkanTensor y{{iN, iC, oH, oW}};
   vulkan::detail::max_pool2d(

test/test_nn.py

Lines changed: 13 additions & 0 deletions
@@ -10463,6 +10463,19 @@ def test_convTranspose_empty(self, device):
         with torch.backends.cudnn.flags(enabled=False):
             self._test_module_empty_input(mod, inp, check_size=False)
 
+    def test_AvgPool2d_empty(self, device):
+        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
+        inp = torch.randn(0, 16, 20, 32, device=device)
+        self._test_module_empty_input(avgpool, inp, check_size=False)
+
+        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(memory_format=torch.channels_last)
+        self._test_module_empty_input(avgpool, clast_inp, check_size=False)
+
+        # test with empty non-batch input
+        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
+            inp = torch.randn(16, 0, 20, 32, device=device)
+            avgpool(inp)
+
     @onlyCUDA
     @largeTensorTest('16GB')
     def test_prelu_backward_32bit_indexing(self, device):
