
Commit 982ae98

Natalia Gimelshein authored and facebook-github-bot committed
Revert D24941350: [pytorch][PR] Reopen PR for 0 dim batch size for AvgPool2d.
Test Plan: revert-hammer

Differential Revision: D24941350 (ceeab70)

Original commit changeset: b7e50346d86e

fbshipit-source-id: 2e42e4418476658dc1afb905184841bf61688cfd
1 parent c543b3b commit 982ae98
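
This revert restores the stricter input checks that the original commit (D24941350) had relaxed to allow a 0-sized batch dimension for AvgPool2d. A minimal sketch of the behavior the restored checks imply, purely for illustration (the exact error text and which released PyTorch versions behave this way are assumptions, not stated in this commit):

    import torch

    avgpool = torch.nn.AvgPool2d(3, stride=2)

    # An ordinary non-empty NCHW batch pools as usual.
    out = avgpool(torch.randn(2, 16, 20, 32))
    print(out.shape)  # torch.Size([2, 16, 9, 15])

    # With the reverted checks, a 0-batch input is expected to be rejected
    # by pool2d_shape_check instead of producing an empty output.
    try:
        avgpool(torch.randn(0, 16, 20, 32))
    except RuntimeError as err:
        print("rejected:", err)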

File tree

7 files changed: +77 −86 lines


aten/src/ATen/native/AveragePool2d.cpp

Lines changed: 9 additions & 3 deletions
@@ -119,6 +119,9 @@ void avg_pool2d_out_cpu_template(
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 
+  TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
+    "non-empty 2D or 3D (batch mode) tensor expected for input");
+
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
 
@@ -136,7 +139,7 @@ void avg_pool2d_out_cpu_template(
     kH, kW, dH, dW, padH, padW, 1, 1,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, input_.suggest_memory_format());
+    outputHeight, outputWidth);
 
   if (input_.ndimension() == 3) {
     output.resize_({nInputPlane, outputHeight, outputWidth});
@@ -273,8 +276,12 @@ Tensor& avg_pool2d_backward_out_cpu_template(
     "avg_pool2d: padding must either be a single int, or a tuple of two ints");
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+
   const int64_t ndim = input.ndimension();
 
+  TORCH_CHECK((ndim == 3 || ndim == 4),
+    "non-empty 3D or 4D (batch mode) tensor expected for input");
+
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero");
 
   /* sizes */
@@ -292,8 +299,7 @@ Tensor& avg_pool2d_backward_out_cpu_template(
     kH, kW, dH, dW, padH, padW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth,
-    input.suggest_memory_format());
+    outputHeight, outputWidth);
 
   /* get contiguous gradOutput */
   const Tensor gradOutput = gradOutput_.contiguous();
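
The checks re-added above restrict avg_pool2d on CPU to 3D (C, H, W) or 4D (N, C, H, W) inputs. A brief sketch of the two accepted layouts (shapes here are illustrative, not taken from the commit):

    import torch
    import torch.nn.functional as F

    x_chw = torch.randn(16, 20, 32)      # 3D, unbatched
    x_nchw = torch.randn(2, 16, 20, 32)  # 4D, batched

    print(F.avg_pool2d(x_chw, kernel_size=3, stride=2).shape)   # (16, 9, 15)
    print(F.avg_pool2d(x_nchw, kernel_size=3, stride=2).shape)  # (2, 16, 9, 15)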

aten/src/ATen/native/DilatedMaxPool2d.cpp

Lines changed: 2 additions & 3 deletions
@@ -169,7 +169,7 @@ void max_pool2d_with_indices_out_cpu_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, input_.suggest_memory_format());
+    outputHeight, outputWidth);
 
   /* get contiguous input */
   Tensor input = input_.contiguous();
@@ -360,8 +360,7 @@ Tensor& max_pool2d_with_indices_backward_out_cpu_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight_for_shape_check, outputWidth_for_shape_check,
-    input.suggest_memory_format());
+    outputHeight_for_shape_check, outputWidth_for_shape_check);
 
   /* backprop */
   if (input.ndimension() == 3)
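
These call sites belong to the CPU max_pool2d_with_indices forward and backward, which Python reaches via max_pool2d with return_indices=True. A small usage sketch for orientation (not derived from this diff):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 16, 20, 32)
    out, idx = F.max_pool2d(x, kernel_size=3, stride=2, return_indices=True)
    print(out.shape, idx.shape)  # both torch.Size([2, 16, 9, 15])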

aten/src/ATen/native/Pool.h

Lines changed: 8 additions & 20 deletions
@@ -54,8 +54,9 @@ pool2d_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
+  int64_t outputHeight, int64_t outputWidth)
 {
+  const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
 
   TORCH_CHECK(kW > 0 && kH > 0,
@@ -68,19 +69,8 @@ pool2d_shape_check(
     "dilation should be greater than zero, but got ",
     "dilationH: ", dilationH, " dilationW: ", dilationW);
 
-  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
-  if (memory_format == at::MemoryFormat::ChannelsLast){
-    // Expect tensor in NHWC format and allow 0-dim only for N.
-    TORCH_CHECK((input.ndimension() == 4 && valid_dims && input.size(3) != 0),
-      "Expected 4D (batch mode) tensor expected for input with channels_last layout"
-      " with optional 0 dim batch size for input, but got: ", input.sizes());
-  } else {
-    TORCH_CHECK((input.ndimension() == 3 && input.size(0) != 0 && valid_dims) ||
-      (input.ndimension() == 4 && valid_dims && input.size(3) != 0),
-      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
-      input.sizes());
-  }
-
+  TORCH_CHECK(input.numel() > 0 && (ndim == 3 || ndim == 4),
+    "non-empty 3D or 4D input tensor expected but got ndim: ", ndim);
   TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
     "pad should be smaller than half of kernel size, but got ",
     "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
@@ -103,13 +93,13 @@ max_pool2d_backward_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format,
+  int64_t outputHeight, int64_t outputWidth,
   bool cuda=false)
 {
   pool2d_shape_check(
     input,
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
 
   const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
@@ -132,14 +122,12 @@ avg_pool2d_backward_shape_check(
   int kH, int kW, int dH, int dW, int padH, int padW,
   int64_t nInputPlane,
   int64_t inputHeight, int64_t inputWidth,
-  int64_t outputHeight, int64_t outputWidth,
-  MemoryFormat memory_format)
+  int64_t outputHeight, int64_t outputWidth)
 {
   pool2d_shape_check(
     input,
     kH, kW, dH, dW, padH, padW, 1, 1,
-    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
-    memory_format);
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
 
   const int64_t ndim = input.ndimension();
   const int64_t nOutputPlane = nInputPlane;
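
After this change pool2d_shape_check no longer receives a memory_format argument and validates the input with a single predicate: non-empty and 3- or 4-dimensional. A rough Python analogue of that predicate, written as a hypothetical helper for illustration only (not an API in the codebase):

    import torch

    def check_pool2d_input(t: torch.Tensor) -> None:
        # Mirrors the restored check: non-empty 3D or 4D input expected.
        if t.numel() == 0 or t.dim() not in (3, 4):
            raise RuntimeError(
                f"non-empty 3D or 4D input tensor expected but got ndim: {t.dim()}"
            )

    check_pool2d_input(torch.randn(2, 16, 20, 32))    # passes
    # check_pool2d_input(torch.randn(0, 16, 20, 32))  # would raise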

aten/src/ATen/native/cuda/AveragePool2d.cu

Lines changed: 54 additions & 42 deletions
@@ -262,6 +262,14 @@ void avg_pool2d_out_cuda_template(
   const int padH = safe_downcast<int, int64_t>(padding[0]);
   const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 
+  const auto memory_format = input_.suggest_memory_format();
+  if (memory_format == at::MemoryFormat::ChannelsLast){
+    TORCH_CHECK(input_.ndimension() == 4,
+      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
+  } else {
+    TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
+      "non-empty 3D or 4D (batch mode) tensor expected for input");
+  }
 
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
@@ -273,14 +281,13 @@ void avg_pool2d_out_cuda_template(
 
   const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
   const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
-  const auto memory_format = input_.suggest_memory_format();
 
   pool2d_shape_check(
     input_,
     kH, kW, dH, dW, padH, padW, 1, 1,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, memory_format);
+    outputHeight, outputWidth);
 
   Tensor input = input_.contiguous(memory_format);
 
@@ -293,36 +300,18 @@ void avg_pool2d_out_cuda_template(
   bool use_divisor = divisor_override.has_value();
   const auto divisor_override_value = use_divisor ? divisor_override.value() : 0;
 
-  if (count != 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
-      "avg_pool2d_out_cuda_frame",
-      [&] {
-        using accscalar_t = acc_type<scalar_t, true>;
-
-        scalar_t *output_data = output.data_ptr<scalar_t>();
-        scalar_t *input_data = input.data_ptr<scalar_t>();
-
-        switch (memory_format){
-          case MemoryFormat::ChannelsLast: {
-            output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
-            avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
+    "avg_pool2d_out_cuda_frame",
+    [&] {
+      using accscalar_t = acc_type<scalar_t, true>;
+
+      scalar_t *output_data = output.data_ptr<scalar_t>();
+      scalar_t *input_data = input.data_ptr<scalar_t>();
+
+      switch (memory_format){
+        case MemoryFormat::ChannelsLast: {
+          output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
+          avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
             <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               count,
               input_data,
@@ -336,13 +325,31 @@ void avg_pool2d_out_cuda_template(
               output_data,
               divisor_override_value,
               count_include_pad, use_divisor);
-            break;
-          }
-          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+          TORCH_CUDA_KERNEL_LAUNCH_CHECK();
+          break;
+        }
+        case MemoryFormat::Contiguous: {
+          avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
+            <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              count,
+              input_data,
+              nbatch,
+              nInputPlane,
+              inputHeight, inputWidth,
+              outputHeight, outputWidth,
+              kH, kW,
+              dH, dW,
+              padH, padW,
+              output_data,
+              divisor_override_value,
+              count_include_pad, use_divisor);
+          TORCH_CUDA_KERNEL_LAUNCH_CHECK();
+          break;
         }
+        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
       }
-      );
-    }
+    }
+  );
 
   if (input.ndimension() == 3) {
     output.resize_({nInputPlane, outputHeight, outputWidth});
@@ -388,6 +395,14 @@ Tensor& avg_pool2d_backward_out_cuda_template(
     "divisor must be not zero");
 
   const auto memory_format = input_.suggest_memory_format();
+  if (memory_format == at::MemoryFormat::ChannelsLast) {
+    TORCH_CHECK(input_.ndimension() == 4,
+      "non-empty 4D (batch mode) tensor expected for input with channels_last layout");
+  } else {
+    TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
+      "non-empty 3D or 4D (batch mode) tensor expected for input");
+  }
+
   const Tensor input = input_.contiguous(memory_format);
   const Tensor gradOutput = gradOutput_.contiguous(memory_format);
 
@@ -406,14 +421,11 @@ Tensor& avg_pool2d_backward_out_cuda_template(
     kH, kW, dH, dW, padH, padW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, memory_format);
+    outputHeight, outputWidth);
 
   gradInput.resize_as_(input);
+
   const int32_t count = safe_downcast<int32_t, int64_t>(input.numel());
-  if (count == 0) {
-    return gradInput;
-  }
-
   const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   const uint32_t num_blocks = cuda::ATenCeilDiv<uint32_t>(count, num_threads);
 
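
On the CUDA path the kernel is still dispatched on the input's suggested memory format, and the re-added check requires a 4D input when that format is channels_last. A small usage sketch under the assumption of a CUDA-enabled build (device and shapes are illustrative):

    import torch

    if torch.cuda.is_available():
        avgpool = torch.nn.AvgPool2d(3, stride=2)
        x = torch.randn(2, 16, 20, 32, device="cuda").contiguous(
            memory_format=torch.channels_last
        )
        y = avgpool(x)  # takes the NHWC (channels_last) CUDA kernel branch
        print(y.shape, y.is_contiguous(memory_format=torch.channels_last))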

aten/src/ATen/native/cuda/DilatedMaxPool2d.cu

Lines changed: 2 additions & 2 deletions
@@ -346,7 +346,7 @@ void max_pool2d_with_indices_out_cuda_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, memory_format);
+    outputHeight, outputWidth);
 
   Tensor input = input_.contiguous(memory_format);
 
@@ -513,7 +513,7 @@ void max_pool2d_with_indices_backward_out_cuda_template(
     kH, kW, dH, dW, padH, padW, dilationH, dilationW,
     nInputPlane,
     inputHeight, inputWidth,
-    outputHeight, outputWidth, memory_format,
+    outputHeight, outputWidth,
     /*cuda=*/ true);
 
   const Tensor gradOutput = gradOutput_.contiguous(memory_format);

aten/src/ATen/native/vulkan/VulkanAten.cpp

Lines changed: 2 additions & 3 deletions
@@ -164,7 +164,7 @@ Tensor avg_pool2d(
       pooling_output_shape<int64_t>(iW, kW, padW, dW, 1, ceil_mode);
 
   pool2d_shape_check(
-      self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW, self.suggest_memory_format());
+      self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW);
 
   VulkanTensor y{{iN, iC, oH, oW}};
   vulkan::detail::avg_pool2d(
@@ -234,8 +234,7 @@ Tensor max_pool2d(
       iH,
       iW,
       oH,
-      oW,
-      self.suggest_memory_format());
+      oW);
 
   VulkanTensor y{{iN, iC, oH, oW}};
   vulkan::detail::max_pool2d(

test/test_nn.py

Lines changed: 0 additions & 13 deletions
@@ -10463,19 +10463,6 @@ def test_convTranspose_empty(self, device):
         with torch.backends.cudnn.flags(enabled=False):
             self._test_module_empty_input(mod, inp, check_size=False)
 
-    def test_AvgPool2d_empty(self, device):
-        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
-        inp = torch.randn(0, 16, 20, 32, device=device)
-        self._test_module_empty_input(avgpool, inp, check_size=False)
-
-        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(memory_format=torch.channels_last)
-        self._test_module_empty_input(avgpool, clast_inp, check_size=False)
-
-        # test with empty non-batch input
-        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
-            inp = torch.randn(16, 0, 20, 32, device=device)
-            avgpool(inp)
-
     @onlyCUDA
     @largeTensorTest('16GB')
     def test_prelu_backward_32bit_indexing(self, device):
