Skip to content

Commit df0427a

Browse files
author
Aapo Kyrola
committed
[mkldnn][pooling] support ceil mode by padding changes
1 parent c8083e0 commit df0427a

File tree

7 files changed

+88
-30
lines changed

7 files changed

+88
-30
lines changed

aten/src/ATen/native/mkldnn/Pooling.cpp

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,69 @@ static Tensor _mkldnn_pool2d(
7373
IntArrayRef dilation,
7474
bool ceil_mode,
7575
ideep::algorithm algo) {
76-
TORCH_CHECK(!ceil_mode, "Currently Mkldnn Pooling operators do not support ceil_mode.");
7776
auto kernel_size_vec = expand_param_if_needed(kernel_size, "kernel_size", 2);
7877
auto stride_vec = expand_param_if_needed(stride, "stride", 2);
7978
auto padding_vec = expand_param_if_needed(padding, "padding", 2);
79+
auto padding_vec_l = padding_vec;
80+
auto padding_vec_r = padding_vec;
8081
auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);
8182

8283
const ideep::tensor& x = itensor_from_mkldnn(input);
83-
const std::vector<int64_t> output_sizes = pool_output_sizes(
84-
input.sizes(),
85-
kernel_size_vec,
86-
stride_vec,
87-
padding_vec,
88-
dilation_vec,
89-
ceil_mode);
84+
std::vector<int64_t> output_sizes;
85+
86+
if (ceil_mode) {
87+
// MKLDNN does not support ceil mode, so we adjust padding
88+
// on the right side to match behavior. Adjust output size
89+
// accordingly.
90+
const std::vector<int64_t> output_sizes_ceil = pool_output_sizes(
91+
input.sizes(),
92+
kernel_size_vec,
93+
stride_vec,
94+
padding_vec_l,
95+
padding_vec_r,
96+
dilation_vec,
97+
true /* ceil_mode */);
98+
99+
// adjust padding until output sizes agree
100+
bool all_equal = false;
101+
while (!all_equal) {
102+
output_sizes = pool_output_sizes(
103+
input.sizes(),
104+
kernel_size_vec,
105+
stride_vec,
106+
padding_vec_l,
107+
padding_vec_r,
108+
dilation_vec,
109+
false /*ceil_mode */);
110+
111+
all_equal = true;
112+
for (size_t i = 2; i < input.sizes().size(); ++i) {
113+
if (output_sizes[i] < output_sizes_ceil[i]) {
114+
padding_vec_r[i - 2]++;
115+
all_equal = false;
116+
}
117+
}
118+
}
119+
} else {
120+
output_sizes = pool_output_sizes(
121+
input.sizes(),
122+
kernel_size_vec,
123+
stride_vec,
124+
padding_vec_l,
125+
padding_vec_r,
126+
dilation_vec,
127+
false /*ceil_mode */);
128+
}
129+
90130
ideep::tensor y;
91131
ideep::pooling_forward::compute<AllocForMKLDNN>(
92132
x,
93133
{output_sizes.cbegin(), output_sizes.cend()},
94134
y,
95135
{stride_vec.cbegin(), stride_vec.cend()},
96136
{kernel_size_vec.cbegin(), kernel_size_vec.cend()},
97-
{padding_vec.cbegin(), padding_vec.cend()},
98-
{padding_vec.cbegin(), padding_vec.cend()},
137+
{padding_vec_l.cbegin(), padding_vec_l.cend()},
138+
{padding_vec_r.cbegin(), padding_vec_r.cend()},
99139
algo,
100140
ideep::prop_kind::forward);
101141

aten/src/ATen/native/mkldnn/Utils.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,21 @@ std::vector<int64_t> pool_output_sizes(
2525
IntArrayRef input_size,
2626
IntArrayRef kernel_size,
2727
IntArrayRef stride,
28-
IntArrayRef padding,
28+
IntArrayRef padding_l,
29+
IntArrayRef padding_r,
2930
IntArrayRef dilation,
3031
bool ceil_mode) {
3132
std::vector<int64_t> output_size(input_size.size());
3233
// copy N and C
3334
output_size[0] = input_size[0];
3435
output_size[1] = input_size[1];
3536

36-
for (int i = 2; i < input_size.size(); ++i) {
37-
output_size[i] = pooling_output_shape<int64_t>(
37+
for (size_t i = 2; i < input_size.size(); ++i) {
38+
output_size[i] = pooling_output_shape_pad_lr<int64_t>(
3839
input_size[i],
3940
kernel_size[i - 2],
40-
padding[i - 2],
41+
padding_l[i - 2],
42+
padding_r[i - 2],
4143
stride[i - 2],
4244
dilation[i - 2],
4345
ceil_mode

aten/src/ATen/native/mkldnn/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ std::vector<int64_t> pool_output_sizes(
1616
IntArrayRef input_size,
1717
IntArrayRef kernel_size,
1818
IntArrayRef stride,
19-
IntArrayRef padding,
19+
IntArrayRef padding_l,
20+
IntArrayRef padding_r,
2021
IntArrayRef dilation,
2122
bool ceil_mode);
2223
}}

aten/src/THNN/generic/pooling_shape.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,26 @@
22
#define THNN_POOLING_SHAPE_H
33

44
// Computes the number of pooling windows along one spatial dimension when the
// left and right padding may differ.
//
//   inputSize  - extent of the input along this dimension
//   kernelSize - extent of the pooling window
//   pad_l      - padding added before the first input element
//   pad_r      - padding added after the last input element
//   stride     - step between consecutive window start positions
//   dilation   - spacing between kernel elements
//   ceil_mode  - round the window count up instead of down
//
// Returns the output extent. The value is zero or negative when the dilated
// kernel does not fit into the padded input; callers are expected to reject
// such sizes.
template<typename T>
static inline T pooling_output_shape_pad_lr(
    T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
    bool ceil_mode
) {
  const T numerator = inputSize + pad_l + pad_r - dilation * (kernelSize - 1)
      - 1 + (ceil_mode ? stride - 1 : 0);

  // C++ integer division truncates toward zero. Round toward negative
  // infinity instead, so that a kernel larger than the padded input never
  // yields a spurious output element.
  T quotient = numerator / stride;
  if ((numerator % stride != 0) && ((numerator < 0) != (stride < 0))) {
    --quotient;
  }

  T outputSize = quotient + 1;

  // Ensure that the last pooling window starts inside the image (or its left
  // padding). This is needed to avoid problems in ceil mode, and it must not
  // depend on pad_l being non-zero: with pad_l == 0 and a stride larger than
  // the kernel span, rounding up can still place a window past the input.
  if ((outputSize - 1) * stride >= inputSize + pad_l) {
    --outputSize;
  }
  return outputSize;
}
1619

20+
// Symmetric-padding convenience wrapper: applies the same padding amount on
// both sides and forwards to pooling_output_shape_pad_lr.
template<typename T>
static inline T pooling_output_shape(
    T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
  // Left and right padding are identical for this overload.
  const T pad_left = pad;
  const T pad_right = pad;
  return pooling_output_shape_pad_lr(
      inputSize, kernelSize, pad_left, pad_right, stride, dilation, ceil_mode);
}
26+
1727
#endif

test/test_mkldnn.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,21 @@ def test_relu_(self):
126126
def test_max_pool2d(self):
127127
N = torch.randint(3, 10, (1,)).item()
128128
C = torch.randint(3, 10, (1,)).item()
129-
x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10
130129

131-
max_pool2d = torch.nn.MaxPool2d(
132-
kernel_size=3,
133-
stride=2,
134-
padding=1)
130+
for stride in [1, 2, 3]:
131+
for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
132+
x = torch.randn(N, C, H, W, dtype=torch.float32) * 10
135133

136-
self.assertEqual(
137-
max_pool2d(x),
138-
max_pool2d(x.to_mkldnn()).to_dense())
134+
for ceil_mode in [False, True]:
135+
max_pool2d = torch.nn.MaxPool2d(
136+
kernel_size=3 if not ceil_mode else 7,
137+
stride=stride,
138+
padding=1,
139+
ceil_mode=ceil_mode)
140+
141+
self.assertEqual(
142+
max_pool2d(x),
143+
max_pool2d(x.to_mkldnn()).to_dense())
139144

140145
def test_avg_pool2d(self):
141146
N = torch.randint(3, 10, (1,)).item()

0 commit comments

Comments
 (0)