Skip to content

Commit 9e3bcf4

Browse files
sighingnowsoumith
authored andcommitted
More strict shape check on Conv operators. (#4637)
* More strict shape check on Conv operators. Signed-off-by: HE, Tao <sighingnow@gmail.com> * Test case for conv's shape check. Signed-off-by: HE, Tao <sighingnow@gmail.com> * Fix lint. Signed-off-by: HE, Tao <sighingnow@gmail.com>
1 parent b4862f6 commit 9e3bcf4

File tree

5 files changed

+84
-11
lines changed

5 files changed

+84
-11
lines changed

test/test_nn.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,6 +1969,32 @@ def test_conv_modules_raise_error_on_incorrect_input_size(self):
19691969
input = Variable(torch.Tensor(torch.Size((3, ) * dims)))
19701970
self.assertRaises(ValueError, lambda: module(input))
19711971

1972+
def test_conv_shapecheck(self):
1973+
def test(should_raise, module, input_size):
1974+
input = Variable(torch.Tensor(3, *input_size))
1975+
if should_raise:
1976+
self.assertRaises(RuntimeError, lambda: module(input))
1977+
else:
1978+
# just run it to ensure no exception raised.
1979+
module(input)
1980+
1981+
# Conv1d
1982+
test(True, nn.Conv1d(1, 1, 3), (1, 2))
1983+
test(True, nn.Conv1d(1, 1, 3, stride=2), (1, 2))
1984+
test(False, nn.Conv1d(1, 1, 2), (1, 2))
1985+
test(False, nn.Conv1d(1, 1, 2, stride=2), (1, 2))
1986+
test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1), (1, 2))
1987+
1988+
# Conv2d
1989+
test(True, nn.Conv2d(1, 1, (3, 3)), (1, 2, 2))
1990+
test(False, nn.Conv2d(1, 1, (3, 3)), (1, 3, 3))
1991+
test(False, nn.Conv2d(1, 1, (3, 3), padding=1), (1, 2, 2))
1992+
1993+
# Conv3D
1994+
test(True, nn.Conv3d(1, 1, (3, 3, 3)), (1, 2, 2, 2))
1995+
test(False, nn.Conv3d(1, 1, (3, 3, 3)), (1, 3, 3, 3))
1996+
test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1), (1, 2, 2, 2))
1997+
19721998
def test_ConvTranspose2d_output_size(self):
19731999
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
19742000
i = Variable(torch.randn(2, 3, 6, 6))

torch/lib/THCUNN/generic/SpatialConvolutionMM.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,18 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
3838
int64_t inputHeight = input->size[dimh];
3939
int64_t inputWidth = input->size[dimw];
4040
int64_t nOutputPlane = weight->size[0];
41-
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
42-
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
41+
42+
int64_t exactInputHeight = inputHeight + 2 * padH;
43+
int64_t exactInputWidth = inputWidth + 2 * padW;
44+
45+
if (exactInputHeight < kH || exactInputWidth < kW) {
46+
THError("Calculated input size: (%d x %d). "
47+
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
48+
exactInputHeight,exactInputWidth,kH,kW);
49+
}
50+
51+
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
52+
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
4353

4454
if (outputWidth < 1 || outputHeight < 1)
4555
THError("Given input size: (%d x %d x %d). "

torch/lib/THCUNN/generic/VolumetricConvolution.cu

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,20 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
7272
int64_t inputWidth = input->size[dimw];
7373
int64_t inputHeight = input->size[dimh];
7474
int64_t inputDepth = input->size[dimd];
75-
int64_t outputWidth = (inputWidth + 2*padH - kH) / dH + 1;
76-
int64_t outputHeight = (inputHeight + 2*padT - kT) / dT + 1;
77-
int64_t outputDepth = (inputDepth + 2*padW - kW) / dW + 1;
75+
76+
int64_t exactInputDepth = inputDepth + 2*padT;
77+
int64_t exactInputHeight = inputHeight + 2*padH;
78+
int64_t exactInputWidth = inputWidth + 2*padW;
79+
80+
if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
81+
THError("Calculated input size: (%d x %d x %d). "
82+
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
83+
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
84+
}
85+
86+
int64_t outputWidth = (exactInputDepth - kH) / dH + 1;
87+
int64_t outputHeight = (exactInputHeight - kT) / dT + 1;
88+
int64_t outputDepth = (exactInputWidth - kW) / dW + 1;
7889

7990
if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
8091
{

torch/lib/THNN/generic/SpatialConvolutionMM.c

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,24 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
3636
int64_t inputHeight = input->size[dimh];
3737
int64_t inputWidth = input->size[dimw];
3838
int64_t nOutputPlane = weight->size[0];
39-
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
40-
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
4139

42-
if (outputWidth < 1 || outputHeight < 1)
40+
int64_t exactInputHeight = inputHeight + 2 * padH;
41+
int64_t exactInputWidth = inputWidth + 2 * padW;
42+
43+
if (exactInputHeight < kH || exactInputWidth < kW) {
44+
THError("Calculated input size: (%d x %d). "
45+
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
46+
exactInputHeight,exactInputWidth,kH,kW);
47+
}
48+
49+
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
50+
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
51+
52+
if (outputWidth < 1 || outputHeight < 1) {
4353
THError("Given input size: (%d x %d x %d). "
4454
"Calculated output size: (%d x %d x %d). Output size is too small",
4555
nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
56+
}
4657

4758
THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
4859

torch/lib/THNN/generic/VolumetricConvolutionMM.c

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
4343
int64_t inputHeight;
4444
int64_t inputWidth;
4545
int64_t nOutputPlane;
46+
47+
int64_t exactInputDepth;
48+
int64_t exactInputHeight;
49+
int64_t exactInputWidth;
4650
int64_t outputDepth;
4751
int64_t outputHeight;
4852
int64_t outputWidth;
@@ -52,9 +56,20 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
5256
inputHeight = input->size[dimh];
5357
inputWidth = input->size[dimw];
5458
nOutputPlane = weight->size[0];
55-
outputDepth = (inputDepth + 2*pT - kT) / dT + 1;
56-
outputHeight = (inputHeight + 2*pH - kH) / dH + 1;
57-
outputWidth = (inputWidth + 2*pW - kW) / dW + 1;
59+
60+
exactInputDepth = inputDepth + 2*pT;
61+
exactInputHeight = inputHeight + 2*pH;
62+
exactInputWidth = inputWidth + 2*pW;
63+
64+
if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
65+
THError("Calculated input size: (%d x %d x %d). "
66+
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
67+
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
68+
}
69+
70+
outputDepth = (exactInputDepth - kT) / dT + 1;
71+
outputHeight = (exactInputHeight - kH) / dH + 1;
72+
outputWidth = (exactInputWidth - kW) / dW + 1;
5873

5974
if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
6075
{

0 commit comments

Comments
 (0)