Skip to content

Commit 59bda9a

Browse files
authored
Fix reflection padding boundary checks (#6438)
* Fix Reflection padding boundary checks * Improve padding docs * fix lint
1 parent 65a8ac0 commit 59bda9a

File tree

8 files changed

+282
-68
lines changed

8 files changed

+282
-68
lines changed

aten/src/THCUNN/generic/SpatialReflectionPadding.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ void THNN_(SpatialReflectionPadding_updateOutput)(THCState *state,
3030
int inputH = THCTensor_(size)(state, input, dimh);
3131
int inputW = THCTensor_(size)(state, input, dimw);
3232

33-
THArgCheck(padL <= inputW && padR <= inputW, 4,
34-
"Padding size should not exceed corresponding input dimension, "
33+
THArgCheck(padL < inputW && padR < inputW, 4,
34+
"Padding size should be less than the corresponding input dimension, "
3535
"but got: padding (%d, %d) at dimension %d of input %s",
3636
padL, padR, dimw, THCTensor_(sizeDesc)(state, input).str);
3737

38-
THArgCheck(padT <= inputH && padB <= inputH, 6,
39-
"Padding size should not exceed corresponding input dimension, "
38+
THArgCheck(padT < inputH && padB < inputH, 6,
39+
"Padding size should be less than the corresponding input dimension, "
4040
"but got: padding (%d, %d) at dimension %d of input %s",
4141
padT, padB, dimh, THCTensor_(sizeDesc)(state, input).str);
4242

aten/src/THCUNN/generic/TemporalReflectionPadding.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ void THNN_(TemporalReflectionPadding_updateOutput)(THCState *state,
2626
int numPlanes = THCTensor_(size)(state, input, planeDim);
2727
int inputW = THCTensor_(size)(state, input, dimw);
2828

29-
THArgCheck(padL <= inputW && padR <= inputW, 4,
30-
"Padding size should not exceed corresponding input dimension, "
29+
THArgCheck(padL < inputW && padR < inputW, 4,
30+
"Padding size should be less than the corresponding input dimension, "
3131
"but got: padding (%d, %d) at dimension %d of input %s",
3232
padL, padR, dimw, THCTensor_(sizeDesc)(state, input).str);
3333

aten/src/THNN/generic/SpatialReflectionPadding.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ void THNN_(SpatialReflectionPadding_updateOutput)(THNNState *state,
8383
iheight = input->size[dimh];
8484
iwidth = input->size[dimw];
8585

86-
THArgCheck(pad_l <= iwidth && pad_r <= iwidth, 4,
87-
"Padding size should not exceed corresponding input dimension, "
86+
THArgCheck(pad_l < iwidth && pad_r < iwidth, 4,
87+
"Padding size should be less than the corresponding input dimension, "
8888
"but got: padding (%d, %d) at dimension %d of input %s",
8989
pad_l, pad_r, dimw, _THSizeDesc(input->size, input->nDimension).str);
9090

91-
THArgCheck(pad_t <= iheight && pad_b <= iheight, 6,
92-
"Padding size should not exceed corresponding input dimension, "
91+
THArgCheck(pad_t < iheight && pad_b < iheight, 6,
92+
"Padding size should be less than the corresponding input dimension, "
9393
"but got: padding (%d, %d) at dimension %d of input %s",
9494
pad_t, pad_b, dimh, _THSizeDesc(input->size, input->nDimension).str);
9595

aten/src/THNN/generic/TemporalReflectionPadding.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ void THNN_(TemporalReflectionPadding_updateOutput)(THNNState *state,
6464
nslices = input->size[dimslices];
6565
iwidth = input->size[dimw];
6666

67-
THArgCheck(pad_l <= iwidth && pad_r <= iwidth, 4,
68-
"Padding size should not exceed corresponding input dimension, "
67+
THArgCheck(pad_l < iwidth && pad_r < iwidth, 4,
68+
"Padding size should be less than the corresponding input dimension, "
6969
"but got: padding (%d, %d) at dimension %d of input %s",
7070
pad_l, pad_r, dimw, _THSizeDesc(input->size, input->nDimension).str);
7171

docs/source/nn.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,24 @@ Pooling Layers
196196
Padding Layers
197197
--------------
198198

199+
:hidden:`ReflectionPad1d`
200+
~~~~~~~~~~~~~~~~~~~~~~~~~
201+
202+
.. autoclass:: ReflectionPad1d
203+
:members:
204+
199205
:hidden:`ReflectionPad2d`
200206
~~~~~~~~~~~~~~~~~~~~~~~~~
201207

202208
.. autoclass:: ReflectionPad2d
203209
:members:
204210

211+
:hidden:`ReplicationPad1d`
212+
~~~~~~~~~~~~~~~~~~~~~~~~~~
213+
214+
.. autoclass:: ReplicationPad1d
215+
:members:
216+
205217
:hidden:`ReplicationPad2d`
206218
~~~~~~~~~~~~~~~~~~~~~~~~~~
207219

@@ -220,12 +232,24 @@ Padding Layers
220232
.. autoclass:: ZeroPad2d
221233
:members:
222234

235+
:hidden:`ConstantPad1d`
236+
~~~~~~~~~~~~~~~~~~~~~~~
237+
238+
.. autoclass:: ConstantPad1d
239+
:members:
240+
223241
:hidden:`ConstantPad2d`
224242
~~~~~~~~~~~~~~~~~~~~~~~
225243

226244
.. autoclass:: ConstantPad2d
227245
:members:
228246

247+
:hidden:`ConstantPad3d`
248+
~~~~~~~~~~~~~~~~~~~~~~~
249+
250+
.. autoclass:: ConstantPad3d
251+
:members:
252+
229253

230254
Non-linear Activations (weighted sum+nonlinearity)
231255
-------------------------------------------------

test/test_nn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,6 +1923,13 @@ def test_pad(self):
19231923
inputs = Variable(torch.randn(1, 2, 3, 4, 4), requires_grad=True)
19241924
self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,)))
19251925

1926+
# assert that reflection padding errors when pad >= input size
1927+
expected_err_msg = r"Padding size should be less than the corresponding input dimension"
1928+
self.assertRaisesRegex(RuntimeError, expected_err_msg,
1929+
lambda: F.pad(torch.randn(1, 1, 2, 3), (1, 1, 3, 0), mode='reflect'))
1930+
self.assertRaisesRegex(RuntimeError, expected_err_msg,
1931+
lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
1932+
19261933
def test_pad_scalar_error(self):
19271934
inputs = torch.tensor(0, requires_grad=True)
19281935
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))

torch/nn/functional.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,10 @@ def pad(input, pad, mode='constant', value=0):
19071907
5D input tensor with padding of the form
19081908
`(padLeft, padRight, padTop, padBottom, padFront, padBack)`. No "reflect" implementation.
19091909
1910+
See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
1911+
:class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
1912+
padding modes works.
1913+
19101914
Args:
19111915
input (Variable): `Nd` tensor
19121916
pad (tuple): m-elem tuple, where :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
@@ -1917,7 +1921,7 @@ def pad(input, pad, mode='constant', value=0):
19171921
19181922
>>> t4d = torch.Tensor(3, 3, 4, 2)
19191923
>>> p1d = (1, 1) # pad last dim by 1 on each side
1920-
>>> out = F.pad(t4d, p1d, "constant", 0)
1924+
>>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding
19211925
>>> print(out.data.size())
19221926
torch.Size([3, 3, 4, 4])
19231927
>>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
@@ -1929,31 +1933,34 @@ def pad(input, pad, mode='constant', value=0):
19291933
>>> out = F.pad(t4d, p3d, "constant", 0)
19301934
>>> print(out.data.size())
19311935
torch.Size([3, 9, 7, 3])
1936+
19321937
"""
19331938
assert len(pad) % 2 == 0, 'Padding length must be divisible by 2'
19341939
assert len(pad) // 2 <= input.dim(), 'Padding length too large'
19351940
if mode == 'constant':
19361941
return ConstantPadNd.apply(input, pad, value)
1937-
elif input.dim() == 3:
1938-
assert len(pad) == 2, '3D tensors expect 2 values for padding'
1939-
if mode == 'reflect':
1940-
return torch._C._nn.reflection_pad1d(input, pad)
1941-
elif mode == 'replicate':
1942-
return torch._C._nn.replication_pad1d(input, pad)
1943-
elif input.dim() == 4:
1944-
assert len(pad) == 4, '4D tensors expect 4 values for padding'
1945-
if mode == 'reflect':
1946-
return torch._C._nn.reflection_pad2d(input, pad)
1947-
elif mode == 'replicate':
1948-
return torch._C._nn.replication_pad2d(input, pad)
1949-
elif input.dim() == 5:
1950-
assert len(pad) == 6, '5D tensors expect 6 values for padding'
1951-
if mode == 'reflect':
1952-
raise NotImplementedError
1953-
elif mode == 'replicate':
1954-
return torch._C._nn.replication_pad3d(input, pad)
19551942
else:
1956-
raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
1943+
assert value == 0, 'Padding mode "{}" doesn\'t take in value argument'.format(mode)
1944+
if input.dim() == 3:
1945+
assert len(pad) == 2, '3D tensors expect 2 values for padding'
1946+
if mode == 'reflect':
1947+
return torch._C._nn.reflection_pad1d(input, pad)
1948+
elif mode == 'replicate':
1949+
return torch._C._nn.replication_pad1d(input, pad)
1950+
elif input.dim() == 4:
1951+
assert len(pad) == 4, '4D tensors expect 4 values for padding'
1952+
if mode == 'reflect':
1953+
return torch._C._nn.reflection_pad2d(input, pad)
1954+
elif mode == 'replicate':
1955+
return torch._C._nn.replication_pad2d(input, pad)
1956+
elif input.dim() == 5:
1957+
assert len(pad) == 6, '5D tensors expect 6 values for padding'
1958+
if mode == 'reflect':
1959+
raise NotImplementedError
1960+
elif mode == 'replicate':
1961+
return torch._C._nn.replication_pad3d(input, pad)
1962+
else:
1963+
raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
19571964

19581965

19591966
# distance

0 commit comments

Comments
 (0)