Commit fe2d9ec

New MaxPool1d without indices implementation
ghstack-source-id: f25c08c
Pull Request resolved: #43745
1 parent: 9063bce
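For orientation, this is the call pattern the change targets; a minimal sketch (not part of the commit) using the at::max_pool1d signature added below, which takes the new index-free CPU kernel when the input lives on CPU and does not require grad:

#include <ATen/ATen.h>

int main() {
  // CPU input that does not require grad: handled by the new
  // max_pool1d_impl / max_pool1d_stub path added in this commit.
  at::Tensor x = at::randn({2, 4, 10});
  at::Tensor y = at::max_pool1d(
      x,
      /*kernel_size=*/{3},
      /*stride=*/{2},
      /*padding=*/{1},
      /*dilation=*/{1},
      /*ceil_mode=*/false);
  // Inputs that require grad, or that live on CUDA, still go through
  // at::max_pool1d_with_indices (see max_pool1d in the first file below).
  return y.size(-1) == 5 ? 0 : 1;  // output width 5 for these parameters
}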

File tree

5 files changed: +269 −15 lines

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/MaxPooling.h>

namespace at {
namespace native {

DEFINE_DISPATCH(max_pool1d_stub);

namespace {

// Compute the output size for the given pooling parameters
inline int64_t output_size(
    int64_t input_size,
    int64_t kernel_size,
    int64_t stride,
    int64_t padding,
    int64_t dilation,
    bool ceil_mode) {
  int64_t num = input_size + 2 * padding - dilation * (kernel_size - 1) - 1;
  // Ensure last kernel window starts within bounds in ceil mode
  if (ceil_mode && stride - dilation * (kernel_size - 1) <= num % stride) {
    return (num + stride - 1) / stride + 1;
  }
  return num / stride + 1;
}

Tensor max_pool1d_impl(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  NoNamesGuard guard;

  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
      "max_pool1d() input tensor must have 2 or 3 dimensions but got ",
      self.dim());
  TORCH_CHECK(
      kernel_size.size() == 1,
      "max_pool1d() kernel_size must be an int or int list of size 1 but got size ",
      kernel_size.size());
  TORCH_CHECK(
      stride.size() == 0 || stride.size() == 1,
      "max_pool1d() stride must be None, an int or int list of size 1 but got size ",
      stride.size());
  TORCH_CHECK(
      padding.size() == 1,
      "max_pool1d() padding must be an int or int list of size 1 but got size ",
      padding.size());
  TORCH_CHECK(
      dilation.size() == 1,
      "max_pool1d() dilation must be an int or int list of size 1 but got size ",
      dilation.size());

  // If stride=None then set it to kernel_size
  if (stride.empty()) {
    stride = kernel_size;
  }

  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
  const int64_t NC = self.size(-2);
  const int64_t IW = self.size(-1);
  const int64_t KW = kernel_size[0];
  const int64_t SJ = stride[0];
  const int64_t PJ = padding[0];
  const int64_t DJ = dilation[0];

  TORCH_CHECK(
      KW > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
      KW);
  TORCH_CHECK(
      SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
  TORCH_CHECK(
      PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
  TORCH_CHECK(
      PJ <= KW / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
      PJ,
      " and kernel_size=",
      KW);
  TORCH_CHECK(
      DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);

  const int64_t OW = output_size(IW, KW, SJ, PJ, DJ, ceil_mode);
  TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
  Tensor output = at::empty({NB, NC, OW}, self.options());

  PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
  max_pool1d_stub(self.device().type(), output, self, params);

  if (self.dim() == 2) {
    output.squeeze_(0);
  }

  guard.reset();
  namedinference::propagate_names(output, self);

  return output;
}

} // namespace

Tensor max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  if (self.requires_grad() || !self.device().is_cpu()) {
    // Needs indices for grad and with_indices defines CUDA dispatch
    return std::get<0>(at::max_pool1d_with_indices(
        self, kernel_size, stride, padding, dilation, ceil_mode));
  }
  return max_pool1d_impl(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}

} // namespace native
} // namespace at
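The output_size arithmetic above can be checked by hand against the new corner-case and error tests further down. A standalone sketch (not part of the commit) that reuses the same formula:

#include <cstdint>
#include <iostream>

// Copy of the output_size arithmetic from max_pool1d_impl above.
int64_t output_size(int64_t input_size, int64_t kernel_size, int64_t stride,
                    int64_t padding, int64_t dilation, bool ceil_mode) {
  int64_t num = input_size + 2 * padding - dilation * (kernel_size - 1) - 1;
  if (ceil_mode && stride - dilation * (kernel_size - 1) <= num % stride) {
    return (num + stride - 1) / stride + 1;
  }
  return num / stride + 1;
}

int main() {
  // Matches test_max_pool1d_corner_cases: input [[1, 2]] with
  // kernel=2, stride=1, padding=1, dilation=2 -> two outputs.
  std::cout << output_size(2, 2, 1, 1, 2, false) << "\n";  // prints 2
  // Same input with stride=2 and ceil_mode=true also yields two outputs,
  // because the extra window still starts within bounds.
  std::cout << output_size(2, 2, 2, 1, 2, true) << "\n";   // prints 2
  // Matches test_max_pool1d_errors: width-0 input with kernel=5 gives
  // the negative size reported as "Invalid computed output size: -4".
  std::cout << output_size(0, 5, 1, 0, 1, false) << "\n";  // prints -4
}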

aten/src/ATen/native/MaxPooling.h

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

// TODO(Heitor) Template by dimension
struct PoolingParams1D {
  int64_t NB; // Number of batches
  int64_t NC; // Number of channels
  int64_t IW; // Input width
  int64_t OW; // Output width
  int64_t KW; // Kernel width
  int64_t SJ; // Column stride
  int64_t PJ; // Column padding
  int64_t DJ; // Column dilation

  // Return index of first output within bounds for this kernel index
  inline int64_t valid_kernel_start(int64_t kj) const {
    int64_t ij = kj * DJ - PJ;
    return ij < 0 ? (-ij + SJ - 1) / SJ : 0;
  }

  // Return index one past last output within bounds for this kernel index
  inline int64_t valid_kernel_end(int64_t kj) const {
    int64_t ij = (OW - 1) * SJ + kj * DJ - PJ;
    return ij >= IW ? OW - (ij - IW + SJ) / SJ : OW;
  }
};

using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);

} // namespace native
} // namespace at
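To make the two index helpers concrete, here is a trimmed copy of the struct (illustration only, not part of the commit; NB/NC dropped), evaluated on the [[1, 2]] corner case exercised in test_nn.py. Each kernel offset kj maps to a contiguous range of output positions whose input reads stay inside the unpadded input:

#include <cstdint>
#include <iostream>

// Trimmed copy of the helpers above, outside of ATen for illustration.
struct Params {
  int64_t IW, OW, KW, SJ, PJ, DJ;
  int64_t valid_kernel_start(int64_t kj) const {
    int64_t ij = kj * DJ - PJ;
    return ij < 0 ? (-ij + SJ - 1) / SJ : 0;
  }
  int64_t valid_kernel_end(int64_t kj) const {
    int64_t ij = (OW - 1) * SJ + kj * DJ - PJ;
    return ij >= IW ? OW - (ij - IW + SJ) / SJ : OW;
  }
};

int main() {
  // Input width 2, output width 2, kernel 2, stride 1, padding 1, dilation 2
  // (the [[1, 2]] -> [[2, 1]] corner case from test_nn.py below).
  Params p{/*IW=*/2, /*OW=*/2, /*KW=*/2, /*SJ=*/1, /*PJ=*/1, /*DJ=*/2};
  for (int64_t kj = 0; kj < p.KW; ++kj) {
    std::cout << "kj=" << kj
              << " writes outputs [" << p.valid_kernel_start(kj)
              << ", " << p.valid_kernel_end(kj) << ")\n";
  }
  // kj=0 writes outputs [1, 2): output 0 would read input index -1 (padding).
  // kj=1 writes outputs [0, 1): output 1 would read input index 2 (past the end).
}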

aten/src/ATen/native/Pooling.cpp

Lines changed: 0 additions & 12 deletions
@@ -107,18 +107,6 @@ Tensor avg_pool1d(
   return output.squeeze(2);
 }
 
-Tensor max_pool1d(
-    const Tensor& self,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode) {
-  auto output_and_indices = at::max_pool1d_with_indices(
-      self, kernel_size, stride, padding, dilation, ceil_mode);
-  return std::get<0>(output_and_indices);
-}
-
 Tensor max_pool2d(
     const Tensor& self,
     IntArrayRef kernel_size,
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/MaxPooling.h>

namespace at {
namespace native {

namespace {

template <typename scalar_t>
inline void max_pool1d_kernel(
    scalar_t* op,
    const scalar_t* ip,
    const PoolingParams1D& p) {
  for (int64_t kj = 0; kj < p.KW; ++kj) {
    int64_t oj = p.valid_kernel_start(kj);
    int64_t oe = p.valid_kernel_end(kj);
    int64_t ij = oj * p.SJ + kj * p.DJ - p.PJ;
    for (; oj < oe; ++oj, ij += p.SJ) {
      bool update_max = std::isnan(ip[ij]) || op[oj] < ip[ij];
      op[oj] = update_max ? ip[ij] : op[oj];
    }
  }
}

void max_pool1d_impl(
    Tensor& output,
    const Tensor& input,
    const PoolingParams1D& p) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
    scalar_t* const OP = output.data_ptr<scalar_t>();
    const scalar_t* const IP = input.contiguous().data_ptr<scalar_t>();

    // Value used for padding
    constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
        ? -std::numeric_limits<scalar_t>::infinity()
        : std::numeric_limits<scalar_t>::lowest();

    at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
      for (int64_t it = begin; it < end; ++it) {
        scalar_t* op = OP + it * p.OW;
        const scalar_t* ip = IP + it * p.IW;
        std::fill_n(op, p.OW, FILL);
        max_pool1d_kernel(op, ip, p);
      }
    });
  });
}

} // namespace

REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);

} // namespace native
} // namespace at
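Putting the pieces together, a standalone replay (not part of the commit) of the fill-then-max loop above for a single channel reproduces the [[1, 2]] -> [[2, 1]] result asserted in test_max_pool1d_corner_cases below:

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-alone replay of the inner loop above for one channel:
// input [1, 2], kernel 2, stride 1, padding 1, dilation 2 -> output [2, 1].
int main() {
  const int64_t IW = 2, OW = 2, KW = 2, SJ = 1, PJ = 1, DJ = 2;
  const std::vector<float> ip = {1.0f, 2.0f};
  // Output starts filled with -inf, the value used for padded positions.
  std::vector<float> op(OW, -std::numeric_limits<float>::infinity());

  for (int64_t kj = 0; kj < KW; ++kj) {
    // valid_kernel_start / valid_kernel_end, inlined for this example.
    int64_t ij0 = kj * DJ - PJ;
    int64_t oj = ij0 < 0 ? (-ij0 + SJ - 1) / SJ : 0;
    int64_t ije = (OW - 1) * SJ + kj * DJ - PJ;
    int64_t oe = ije >= IW ? OW - (ije - IW + SJ) / SJ : OW;
    int64_t ij = oj * SJ + kj * DJ - PJ;
    for (; oj < oe; ++oj, ij += SJ) {
      if (std::isnan(ip[ij]) || op[oj] < ip[ij]) {
        op[oj] = ip[ij];
      }
    }
  }
  std::cout << op[0] << " " << op[1] << "\n";  // prints "2 1"
}

Because each kernel offset writes a contiguous, increasing range of output positions, the inner loop needs no bounds checks and the padding is never materialized; the std::isnan branch lets NaN inputs overwrite the running maximum, matching the behaviour exercised by test_max_pool_nan_inf.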

test/test_nn.py

Lines changed: 49 additions & 3 deletions
@@ -42,7 +42,7 @@
     module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
     ctcloss_reference, new_module_tests
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
-    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, \
+    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, largeCUDATensorTest, onlyOnCPUAndCUDA, \
     deviceCountAtLeast, expectedAlertNondeterministic, largeTensorTest
 from torch.nn import MultiheadAttention
@@ -9810,6 +9810,41 @@ def helper(n, c, h, w, kernel_size, stride=None,
         helper(10, 512, 31, 31, 3, stride=2)
         helper(1, 129, 8, 8, 3, stride=2)
 
+    @onlyCPU
+    @dtypes(torch.float)
+    def test_max_pool1d_errors(self, device, dtype):
+        def check(x, args, message):
+            model = torch.nn.MaxPool1d(*args)
+            with self.assertRaisesRegex(RuntimeError, r'max_pool1d\(\) ' + message):
+                model(torch.tensor(x, device=device, dtype=dtype))
+
+        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+        check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
+        check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
+        check([[]], (1, 0), "stride must be greater than zero, but got 0")
+        check([[]], (1, 1, -1), "padding must be non-negative, but got -1")
+        check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1")
+        check([[]], (1, 1, 0, 0), "dilation must be greater than zero, but got 0")
+        check([[]], (5, 1, 0, 1), "Invalid computed output size: -4")
+
+    @onlyCPU
+    @dtypes(torch.float, torch.double)
+    def test_max_pool1d_corner_cases(self, device, dtype):
+        def check(x, args, expected):
+            model = torch.nn.MaxPool1d(*args)
+            tensor = torch.tensor(x, device=device, dtype=dtype)
+            self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype))
+
+        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+        check([[]], (1, None, 0, 1, False, False), [[]])
+        check([[[]]], (1, None, 0, 1, False, False), [[[]]])
+        check([[[]]], (2, 1, 1, 2, False, True), [[[]]])
+        check([[1]], (1, None, 0, 1, False, False), [[1]])
+        check([[1]], (2, None, 1, 2, False, False), [[float('-inf')]])
+        check([[1], [1]], (2, None, 1, 2, False, False), [[float('-inf')], [float('-inf')]])
+        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
+        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])
+
     @onlyCUDA
     def test_max_pool2d(self, device):
         def helper(n, c, h, w, ks):
@@ -11328,15 +11363,22 @@ def test_max_pool_nan_inf(self, device, dtype):
             for num_dim in [1, 2, 3]:
                 fn_name = '{}max_pool{}d'.format(adaptive, num_dim)
                 fn = getattr(F, fn_name)
+
                 x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
                 res = fn(x, 1 if adaptive else 3)
                 res.backward(torch.randn_like(res))
                 self.assertTrue(math.isnan(res.item()))
+                x.requires_grad_(False)
+                res = fn(x, 1 if adaptive else 3)
+                self.assertTrue(math.isnan(res.item()))
 
                 x2 = torch.full([1, 1] + num_dim * [3], -inf, device=device, dtype=dtype, requires_grad=True)
                 res2 = fn(x2, 1 if adaptive else 3)
                 res2.backward(torch.randn_like(res2))
                 self.assertTrue(math.isinf(res2.item()))
+                x2.requires_grad_(False)
+                res2 = fn(x2, 1 if adaptive else 3)
+                self.assertTrue(math.isinf(res2.item()))
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.float, torch.double)
@@ -11373,12 +11415,12 @@ def test_pooling_zero_stride(self, device):
                 fn_name = '{}_pool{}d'.format(op, num_dim)
                 fn = getattr(F, fn_name)
                 x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn(x, kernel_size=2, stride=0))
 
                 fn_module_name = '{}Pool{}d'.format(op.title(), num_dim)
                 fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn_module(x))
 
     @dtypesIfCUDA(*ALL_TENSORTYPES2)
@@ -11401,6 +11443,10 @@ def test_pool_invalid_size(self, device, dtype):
         for op in ('max', 'avg'):
             for num_dim in [1, 2, 3]:
                 fn_name = '{}_pool{}d'.format(op, num_dim)
+                if op == 'max':
+                    # New implementation without indices supports empty tensors
+                    # TODO(Heitor) change once with_indices code is updated
+                    fn_name += '_with_indices'
                 fn = getattr(F, fn_name)
                 # use a configuration that gives zero outputs only
                 # when doing a correct floor division by the stride
