Skip to content

Commit 5343b71

Browse files
sighingnowsoumith
authored andcommitted
More strict shape check on Conv operators. (#4637)
* More strict shape check on Conv operators. Signed-off-by: HE, Tao <sighingnow@gmail.com> * Test case for conv's shape check. Signed-off-by: HE, Tao <sighingnow@gmail.com> * Fix lint. Signed-off-by: HE, Tao <sighingnow@gmail.com>
1 parent 841ce42 commit 5343b71

File tree

5 files changed

+84
-11
lines changed

5 files changed

+84
-11
lines changed

aten/src/THCUNN/generic/SpatialConvolutionMM.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,18 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
3838
int64_t inputHeight = input->size[dimh];
3939
int64_t inputWidth = input->size[dimw];
4040
int64_t nOutputPlane = weight->size[0];
41-
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
42-
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
41+
42+
int64_t exactInputHeight = inputHeight + 2 * padH;
43+
int64_t exactInputWidth = inputWidth + 2 * padW;
44+
45+
if (exactInputHeight < kH || exactInputWidth < kW) {
46+
THError("Calculated input size: (%d x %d). "
47+
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
48+
exactInputHeight,exactInputWidth,kH,kW);
49+
}
50+
51+
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
52+
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
4353

4454
if (outputWidth < 1 || outputHeight < 1)
4555
THError("Given input size: (%d x %d x %d). "

aten/src/THCUNN/generic/VolumetricConvolution.cu

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,20 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
7272
int64_t inputWidth = input->size[dimw];
7373
int64_t inputHeight = input->size[dimh];
7474
int64_t inputDepth = input->size[dimd];
75-
int64_t outputWidth = (inputWidth + 2*padH - kH) / dH + 1;
76-
int64_t outputHeight = (inputHeight + 2*padT - kT) / dT + 1;
77-
int64_t outputDepth = (inputDepth + 2*padW - kW) / dW + 1;
75+
76+
int64_t exactInputDepth = inputDepth + 2*padT;
77+
int64_t exactInputHeight = inputHeight + 2*padH;
78+
int64_t exactInputWidth = inputWidth + 2*padW;
79+
80+
if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
81+
THError("Calculated input size: (%d x %d x %d). "
82+
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
83+
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
84+
}
85+
86+
int64_t outputWidth = (exactInputDepth - kH) / dH + 1;
87+
int64_t outputHeight = (exactInputHeight - kT) / dT + 1;
88+
int64_t outputDepth = (exactInputWidth - kW) / dW + 1;
7889

7990
if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
8091
{

aten/src/THNN/generic/SpatialConvolutionMM.c

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,24 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
3636
int64_t inputHeight = input->size[dimh];
3737
int64_t inputWidth = input->size[dimw];
3838
int64_t nOutputPlane = weight->size[0];
39-
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
40-
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
4139

42-
if (outputWidth < 1 || outputHeight < 1)
40+
int64_t exactInputHeight = inputHeight + 2 * padH;
41+
int64_t exactInputWidth = inputWidth + 2 * padW;
42+
43+
if (exactInputHeight < kH || exactInputWidth < kW) {
44+
THError("Calculated input size: (%d x %d). "
45+
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
46+
exactInputHeight,exactInputWidth,kH,kW);
47+
}
48+
49+
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
50+
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
51+
52+
if (outputWidth < 1 || outputHeight < 1) {
4353
THError("Given input size: (%d x %d x %d). "
4454
"Calculated output size: (%d x %d x %d). Output size is too small",
4555
nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
56+
}
4657

4758
THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
4859

aten/src/THNN/generic/VolumetricConvolutionMM.c

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
4343
int64_t inputHeight;
4444
int64_t inputWidth;
4545
int64_t nOutputPlane;
46+
47+
int64_t exactInputDepth;
48+
int64_t exactInputHeight;
49+
int64_t exactInputWidth;
4650
int64_t outputDepth;
4751
int64_t outputHeight;
4852
int64_t outputWidth;
@@ -52,9 +56,20 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
5256
inputHeight = input->size[dimh];
5357
inputWidth = input->size[dimw];
5458
nOutputPlane = weight->size[0];
55-
outputDepth = (inputDepth + 2*pT - kT) / dT + 1;
56-
outputHeight = (inputHeight + 2*pH - kH) / dH + 1;
57-
outputWidth = (inputWidth + 2*pW - kW) / dW + 1;
59+
60+
exactInputDepth = inputDepth + 2*pT;
61+
exactInputHeight = inputHeight + 2*pH;
62+
exactInputWidth = inputWidth + 2*pW;
63+
64+
if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
65+
THError("Calculated input size: (%d x %d x %d). "
66+
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
67+
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
68+
}
69+
70+
outputDepth = (exactInputDepth - kT) / dT + 1;
71+
outputHeight = (exactInputHeight - kH) / dH + 1;
72+
outputWidth = (exactInputWidth - kW) / dW + 1;
5873

5974
if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
6075
{

test/test_nn.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2202,6 +2202,32 @@ def test_conv_modules_raise_error_on_incorrect_input_size(self):
22022202
input = Variable(torch.Tensor(torch.Size((3, ) * dims)))
22032203
self.assertRaises(RuntimeError, lambda: module(input))
22042204

2205+
def test_conv_shapecheck(self):
2206+
def test(should_raise, module, input_size):
2207+
input = Variable(torch.Tensor(3, *input_size))
2208+
if should_raise:
2209+
self.assertRaises(RuntimeError, lambda: module(input))
2210+
else:
2211+
# just run it to ensure no exception raised.
2212+
module(input)
2213+
2214+
# Conv1d
2215+
test(True, nn.Conv1d(1, 1, 3), (1, 2))
2216+
test(True, nn.Conv1d(1, 1, 3, stride=2), (1, 2))
2217+
test(False, nn.Conv1d(1, 1, 2), (1, 2))
2218+
test(False, nn.Conv1d(1, 1, 2, stride=2), (1, 2))
2219+
test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1), (1, 2))
2220+
2221+
# Conv2d
2222+
test(True, nn.Conv2d(1, 1, (3, 3)), (1, 2, 2))
2223+
test(False, nn.Conv2d(1, 1, (3, 3)), (1, 3, 3))
2224+
test(False, nn.Conv2d(1, 1, (3, 3), padding=1), (1, 2, 2))
2225+
2226+
# Conv3D
2227+
test(True, nn.Conv3d(1, 1, (3, 3, 3)), (1, 2, 2, 2))
2228+
test(False, nn.Conv3d(1, 1, (3, 3, 3)), (1, 3, 3, 3))
2229+
test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1), (1, 2, 2, 2))
2230+
22052231
def test_ConvTranspose2d_output_size(self):
22062232
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
22072233
i = Variable(torch.randn(2, 3, 6, 6))

0 commit comments

Comments
 (0)