4242 module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
4343 ctcloss_reference, new_module_tests
4444from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
45- dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, \
45+ dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
4646 skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, largeCUDATensorTest, onlyOnCPUAndCUDA, \
4747 deviceCountAtLeast, expectedAlertNondeterministic, largeTensorTest
4848from torch.nn import MultiheadAttention
@@ -9810,6 +9810,41 @@ def helper(n, c, h, w, kernel_size, stride=None,
98109810 helper(10, 512, 31, 31, 3, stride=2)
98119811 helper(1, 129, 8, 8, 3, stride=2)
98129812
9813+ @onlyCPU
9814+ @dtypes(torch.float)
9815+ def test_max_pool1d_errors(self, device, dtype):
9816+ def check(x, args, message):
9817+ model = torch.nn.MaxPool1d(*args)
9818+ with self.assertRaisesRegex(RuntimeError, r'max_pool1d\(\) ' + message):
9819+ model(torch.tensor(x, device=device, dtype=dtype))
9820+
9821+ # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
9822+ check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
9823+ check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
9824+ check([[]], (1, 0), "stride must be greater than zero, but got 0")
9825+ check([[]], (1, 1, -1), "padding must be non-negative, but got -1")
9826+ check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1")
9827+ check([[]], (1, 1, 0, 0), "dilation must be greater than zero, but got 0")
9828+ check([[]], (5, 1, 0, 1), "Invalid computed output size: -4")
9829+
9830+ @onlyCPU
9831+ @dtypes(torch.float, torch.double)
9832+ def test_max_pool1d_corner_cases(self, device, dtype):
9833+ def check(x, args, expected):
9834+ model = torch.nn.MaxPool1d(*args)
9835+ tensor = torch.tensor(x, device=device, dtype=dtype)
9836+ self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype))
9837+
9838+ # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
9839+ check([[]], (1, None, 0, 1, False, False), [[]])
9840+ check([[[]]], (1, None, 0, 1, False, False), [[[]]])
9841+ check([[[]]], (2, 1, 1, 2, False, True), [[[]]])
9842+ check([[1]], (1, None, 0, 1, False, False), [[1]])
9843+ check([[1]], (2, None, 1, 2, False, False), [[float('-inf')]])
9844+ check([[1], [1]], (2, None, 1, 2, False, False), [[float('-inf')], [float('-inf')]])
9845+ check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
9846+ check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])
9847+
98139848 @onlyCUDA
98149849 def test_max_pool2d(self, device):
98159850 def helper(n, c, h, w, ks):
@@ -11328,15 +11363,22 @@ def test_max_pool_nan_inf(self, device, dtype):
1132811363 for num_dim in [1, 2, 3]:
1132911364 fn_name = '{}max_pool{}d'.format(adaptive, num_dim)
1133011365 fn = getattr(F, fn_name)
11366+
1133111367 x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
1133211368 res = fn(x, 1 if adaptive else 3)
1133311369 res.backward(torch.randn_like(res))
1133411370 self.assertTrue(math.isnan(res.item()))
11371+ x.requires_grad_(False)
11372+ res = fn(x, 1 if adaptive else 3)
11373+ self.assertTrue(math.isnan(res.item()))
1133511374
1133611375 x2 = torch.full([1, 1] + num_dim * [3], -inf, device=device, dtype=dtype, requires_grad=True)
1133711376 res2 = fn(x2, 1 if adaptive else 3)
1133811377 res2.backward(torch.randn_like(res2))
1133911378 self.assertTrue(math.isinf(res2.item()))
11379+ x2.requires_grad_(False)
11380+ res2 = fn(x2, 1 if adaptive else 3)
11381+ self.assertTrue(math.isinf(res2.item()))
1134011382
1134111383 @onlyOnCPUAndCUDA
1134211384 @dtypes(torch.float, torch.double)
@@ -11373,12 +11415,12 @@ def test_pooling_zero_stride(self, device):
1137311415 fn_name = '{}_pool{}d'.format(op, num_dim)
1137411416 fn = getattr(F, fn_name)
1137511417 x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
11376- self.assertRaisesRegex(RuntimeError, "stride should not be zero",
11418+ self.assertRaisesRegex(RuntimeError, r "stride should not be zero|stride must be greater than zero",
1137711419 lambda: fn(x, kernel_size=2, stride=0))
1137811420
1137911421 fn_module_name = '{}Pool{}d'.format(op.title(), num_dim)
1138011422 fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
11381- self.assertRaisesRegex(RuntimeError, "stride should not be zero",
11423+ self.assertRaisesRegex(RuntimeError, r "stride should not be zero|stride must be greater than zero",
1138211424 lambda: fn_module(x))
1138311425
1138411426 @dtypesIfCUDA(*ALL_TENSORTYPES2)
@@ -11401,6 +11443,10 @@ def test_pool_invalid_size(self, device, dtype):
1140111443 for op in ('max', 'avg'):
1140211444 for num_dim in [1, 2, 3]:
1140311445 fn_name = '{}_pool{}d'.format(op, num_dim)
11446+ if op == 'max':
11447+ # New implementation without indices supports empty tensors
11448+ # TODO(Heitor) change once with_indices code is updated
11449+ fn_name += '_with_indices'
1140411450 fn = getattr(F, fn_name)
1140511451 # use a configuration that gives zero outputs only
1140611452 # when doing a correct floor division by the stride
0 commit comments