Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions aten/src/THCUNN/generic/SpatialConvolutionMM.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,18 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
int64_t inputHeight = input->size[dimh];
int64_t inputWidth = input->size[dimw];
int64_t nOutputPlane = weight->size[0];
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;

int64_t exactInputHeight = inputHeight + 2 * padH;
int64_t exactInputWidth = inputWidth + 2 * padW;

if (exactInputHeight < kH || exactInputWidth < kW) {
THError("Calculated input size: (%d x %d). "
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
exactInputHeight,exactInputWidth,kH,kW);
}

int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;

if (outputWidth < 1 || outputHeight < 1)
THError("Given input size: (%d x %d x %d). "
Expand Down
17 changes: 14 additions & 3 deletions aten/src/THCUNN/generic/VolumetricConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,20 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
int64_t inputWidth = input->size[dimw];
int64_t inputHeight = input->size[dimh];
int64_t inputDepth = input->size[dimd];
int64_t outputWidth = (inputWidth + 2*padH - kH) / dH + 1;
int64_t outputHeight = (inputHeight + 2*padT - kT) / dT + 1;
int64_t outputDepth = (inputDepth + 2*padW - kW) / dW + 1;

int64_t exactInputDepth = inputDepth + 2*padT;
int64_t exactInputHeight = inputHeight + 2*padH;
int64_t exactInputWidth = inputWidth + 2*padW;

if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
THError("Calculated input size: (%d x %d x %d). "
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
}

int64_t outputWidth = (exactInputDepth - kH) / dH + 1;
int64_t outputHeight = (exactInputHeight - kT) / dT + 1;
int64_t outputDepth = (exactInputWidth - kW) / dW + 1;

if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
{
Expand Down
17 changes: 14 additions & 3 deletions aten/src/THNN/generic/SpatialConvolutionMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,24 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
int64_t inputHeight = input->size[dimh];
int64_t inputWidth = input->size[dimw];
int64_t nOutputPlane = weight->size[0];
int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;

if (outputWidth < 1 || outputHeight < 1)
int64_t exactInputHeight = inputHeight + 2 * padH;
int64_t exactInputWidth = inputWidth + 2 * padW;

if (exactInputHeight < kH || exactInputWidth < kW) {
THError("Calculated input size: (%d x %d). "
"Kernel size: (%d x %d). Kernel size can't greater than actual input size",
exactInputHeight,exactInputWidth,kH,kW);
}

int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;

if (outputWidth < 1 || outputHeight < 1) {
THError("Given input size: (%d x %d x %d). "
"Calculated output size: (%d x %d x %d). Output size is too small",
nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth);
}

THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);

Expand Down
21 changes: 18 additions & 3 deletions aten/src/THNN/generic/VolumetricConvolutionMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
int64_t inputHeight;
int64_t inputWidth;
int64_t nOutputPlane;

int64_t exactInputDepth;
int64_t exactInputHeight;
int64_t exactInputWidth;
int64_t outputDepth;
int64_t outputHeight;
int64_t outputWidth;
Expand All @@ -52,9 +56,20 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
inputHeight = input->size[dimh];
inputWidth = input->size[dimw];
nOutputPlane = weight->size[0];
outputDepth = (inputDepth + 2*pT - kT) / dT + 1;
outputHeight = (inputHeight + 2*pH - kH) / dH + 1;
outputWidth = (inputWidth + 2*pW - kW) / dW + 1;

exactInputDepth = inputDepth + 2*pT;
exactInputHeight = inputHeight + 2*pH;
exactInputWidth = inputWidth + 2*pW;

if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) {
THError("Calculated input size: (%d x %d x %d). "
"Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size",
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
}

outputDepth = (exactInputDepth - kT) / dT + 1;
outputHeight = (exactInputHeight - kH) / dH + 1;
outputWidth = (exactInputWidth - kW) / dW + 1;

if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
{
Expand Down
26 changes: 26 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,32 @@ def test_conv_modules_raise_error_on_incorrect_input_size(self):
input = Variable(torch.Tensor(torch.Size((3, ) * dims)))
self.assertRaises(RuntimeError, lambda: module(input))

def test_conv_shapecheck(self):
def test(should_raise, module, input_size):
input = Variable(torch.Tensor(3, *input_size))
if should_raise:
self.assertRaises(RuntimeError, lambda: module(input))
else:
# just run it to ensure no exception raised.
module(input)

# Conv1d
test(True, nn.Conv1d(1, 1, 3), (1, 2))
test(True, nn.Conv1d(1, 1, 3, stride=2), (1, 2))
test(False, nn.Conv1d(1, 1, 2), (1, 2))
test(False, nn.Conv1d(1, 1, 2, stride=2), (1, 2))
test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1), (1, 2))

# Conv2d
test(True, nn.Conv2d(1, 1, (3, 3)), (1, 2, 2))
test(False, nn.Conv2d(1, 1, (3, 3)), (1, 3, 3))
test(False, nn.Conv2d(1, 1, (3, 3), padding=1), (1, 2, 2))

# Conv3D
test(True, nn.Conv3d(1, 1, (3, 3, 3)), (1, 2, 2, 2))
test(False, nn.Conv3d(1, 1, (3, 3, 3)), (1, 3, 3, 3))
test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1), (1, 2, 2, 2))

def test_ConvTranspose2d_output_size(self):
m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
i = Variable(torch.randn(2, 3, 6, 6))
Expand Down