Skip to content

Commit 2372e7e

Browse files
skrah authored and facebook-github-bot committed
DilatedMaxPool: expand incomplete kernel_size for the C++ API (#22073)
Summary: Fixes #22032. Pull Request resolved: #22073 Differential Revision: D15944471 Pulled By: mrshenli fbshipit-source-id: 84b265be00d67aa7f13508ede0646763d2339f1d
1 parent b2a3931 commit 2372e7e

File tree

4 files changed

+28
-36
lines changed

4 files changed

+28
-36
lines changed

aten/src/ATen/native/DilatedMaxPool2d.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,8 @@ void max_pool2d_with_indices_out_cpu_template(
129129
IntArrayRef dilation,
130130
bool ceil_mode)
131131
{
132-
// XXX JIT: Pooling.cpp allows stride.empty().
133-
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
134-
TORCH_CHECK(kernel_size.size() == 2 &&
132+
// #20866, #22032: Guarantee this for the official C++ API?
133+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
135134
(stride.empty() || stride.size() == 2) &&
136135
(padding.size() == 1 || padding.size() == 2) &&
137136
(dilation.size() == 1 || dilation.size() == 2),
@@ -141,7 +140,7 @@ void max_pool2d_with_indices_out_cpu_template(
141140
"non-empty 3D or 4D (batch mode) tensor expected for input");
142141

143142
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
144-
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
143+
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
145144

146145
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
147146
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
@@ -303,9 +302,8 @@ Tensor& max_pool2d_with_indices_backward_out_cpu_template(
303302
IntArrayRef dilation,
304303
bool ceil_mode)
305304
{
306-
// XXX JIT: Pooling.cpp allows stride.empty().
307-
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
308-
TORCH_CHECK(kernel_size.size() == 2 &&
305+
// #20866, #22032: Guarantee this for the official C++ API?
306+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
309307
(stride.empty() || stride.size() == 2) &&
310308
(padding.size() == 1 || padding.size() == 2) &&
311309
(dilation.size() == 1 || dilation.size() == 2),
@@ -315,7 +313,7 @@ Tensor& max_pool2d_with_indices_backward_out_cpu_template(
315313
"non-empty 3D or 4D (batch mode) tensor expected for input");
316314

317315
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
318-
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
316+
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
319317

320318
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
321319
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);

aten/src/ATen/native/DilatedMaxPool3d.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ void max_pool3d_with_indices_out_cpu_template(
148148
IntArrayRef dilation,
149149
bool ceil_mode)
150150
{
151-
// XXX [JIT] Pooling.cpp allows stride.empty().
152-
// XXX [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
153-
TORCH_CHECK(kernel_size.size() == 3 &&
151+
// #20866, #22032: Guarantee this for the official C++ API?
152+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
154153
(stride.empty() || stride.size() == 3) &&
155154
(padding.size() == 1 || padding.size() == 3) &&
156155
(dilation.size() == 1 || dilation.size() == 3),
@@ -160,8 +159,8 @@ void max_pool3d_with_indices_out_cpu_template(
160159
"non-empty 4D or 5D (batch mode) tensor expected for input");
161160

162161
const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
163-
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
164-
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
162+
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
163+
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
165164

166165
const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
167166
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
@@ -353,9 +352,8 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template(
353352
IntArrayRef dilation,
354353
bool ceil_mode)
355354
{
356-
// XXX [JIT] Pooling.cpp allows stride.empty().
357-
// XXX [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
358-
TORCH_CHECK(kernel_size.size() == 3 &&
355+
// #20866, #22032: Guarantee this for the official C++ API?
356+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
359357
(stride.empty() || stride.size() == 3) &&
360358
(padding.size() == 1 || padding.size() == 3) &&
361359
(dilation.size() == 1 || dilation.size() == 3),
@@ -365,8 +363,8 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template(
365363
"non-empty 4D or 5D (batch mode) tensor expected for input");
366364

367365
const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
368-
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
369-
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
366+
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
367+
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
370368

371369
const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
372370
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);

aten/src/ATen/native/cuda/DilatedMaxPool2d.cu

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,8 @@ void max_pool2d_with_indices_out_cuda_template(
146146
checkAllSameGPU("max_pool2d_with_indices_out_cuda",
147147
{output_arg, indices_arg, input_arg});
148148

149-
// XXX JIT: Pooling.cpp allows stride.empty().
150-
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
151-
TORCH_CHECK(kernel_size.size() == 2 &&
149+
// #20866, #22032: Guarantee this for the official C++ API?
150+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
152151
(stride.empty() || stride.size() == 2) &&
153152
(padding.size() == 1 || padding.size() == 2) &&
154153
(dilation.size() == 1 || dilation.size() == 2),
@@ -158,7 +157,7 @@ void max_pool2d_with_indices_out_cuda_template(
158157
"non-empty 3D or 4D (batch mode) tensor expected for input");
159158

160159
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
161-
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
160+
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
162161

163162
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
164163
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
@@ -237,9 +236,8 @@ void max_pool2d_with_indices_backward_out_cuda_template(
237236
checkAllSameGPU("max_pool2d_with_indices_out_cuda",
238237
{gradInput_arg, gradOutput_arg, input_arg, indices_arg});
239238

240-
// XXX JIT: Pooling.cpp allows stride.empty().
241-
// XXX IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
242-
TORCH_CHECK(kernel_size.size() == 2 &&
239+
// #20866, #22032: Guarantee this for the official C++ API?
240+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
243241
(stride.empty() || stride.size() == 2) &&
244242
(padding.size() == 1 || padding.size() == 2) &&
245243
(dilation.size() == 1 || dilation.size() == 2),
@@ -249,7 +247,7 @@ void max_pool2d_with_indices_backward_out_cuda_template(
249247
"non-empty 3D or 4D (batch mode) tensor expected for input");
250248

251249
const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
252-
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
250+
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
253251

254252
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
255253
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);

aten/src/ATen/native/cuda/DilatedMaxPool3d.cu

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,8 @@ void max_pool3d_with_indices_out_cuda_template(
290290
checkAllSameGPU("max_pool3d_with_indices_out_cuda",
291291
{output_arg, indices_arg, input_arg});
292292

293-
// XXX [JIT] Pooling.cpp allows stride.empty().
294-
// XXX [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
295-
TORCH_CHECK(kernel_size.size() == 3 &&
293+
// #20866, #22032: Guarantee this for the official C++ API?
294+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
296295
(stride.empty() || stride.size() == 3) &&
297296
(padding.size() == 1 || padding.size() == 3) &&
298297
(dilation.size() == 1 || dilation.size() == 3),
@@ -302,8 +301,8 @@ void max_pool3d_with_indices_out_cuda_template(
302301
"non-empty 4D or 5D (batch mode) tensor expected for input");
303302

304303
const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
305-
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
306-
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
304+
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
305+
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
307306

308307
const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
309308
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
@@ -395,9 +394,8 @@ void max_pool3d_with_indices_backward_out_cuda_template(
395394
checkAllSameGPU("max_pool3d_with_indices_backward_out_cuda",
396395
{gradInput_arg, gradOutput_arg, input_arg, indices_arg});
397396

398-
// XXX [JIT] Pooling.cpp allows stride.empty().
399-
// XXX [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1 && dilation.size() == 1.
400-
TORCH_CHECK(kernel_size.size() == 3 &&
397+
// #20866, #22032: Guarantee this for the official C++ API?
398+
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
401399
(stride.empty() || stride.size() == 3) &&
402400
(padding.size() == 1 || padding.size() == 3) &&
403401
(dilation.size() == 1 || dilation.size() == 3),
@@ -414,8 +412,8 @@ void max_pool3d_with_indices_backward_out_cuda_template(
414412
gradInput.zero_();
415413

416414
const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
417-
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
418-
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
415+
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
416+
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);
419417

420418
const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
421419
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);

0 commit comments

Comments (0)