Skip to content

Commit 6ed41ad

Browse files
Owen Andersonfacebook-github-bot
authored andcommitted
Use round-to-negative division when computing output sizes for convolutions involving striding and dilation.
Summary: Pull Request resolved: #9640 Differential Revision: D8948081 Pulled By: resistor fbshipit-source-id: 06f2e3ad1bdb448be6f36577cb9bd27c884df595
1 parent 8c0355c commit 6ed41ad

14 files changed

+73
-28
lines changed

aten/src/ATen/div_rtn.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
// Integer division rounding to -Infinity
4+
template<typename T>
5+
static inline T div_rtn(T x, T y) {
6+
int q = x/y;
7+
int r = x%y;
8+
if ((r!=0) && ((r<0) != (y<0))) --q;
9+
return q;
10+
}
11+

aten/src/THCUNN/generic/Col2Im.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/Col2Im.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(Col2Im_shapeCheck)(
68
THCState *state,
79
THCTensor *input,
@@ -31,8 +33,8 @@ static inline void THNN_(Col2Im_shapeCheck)(
3133
}
3234

3335
int64_t inputLength = input->size(batch_dim + 2);
34-
int64_t nBlocksH = 1 + (outputHeight + 2 * padH - dH * (kH - 1) - 1) / sH;
35-
int64_t nBlocksW = 1 + ( outputWidth + 2 * padW - dW * (kW - 1) - 1) / sW;
36+
int64_t nBlocksH = div_rtn<int64_t>(outputHeight + 2 * padH - dH * (kH - 1) - 1, sH) + 1;
37+
int64_t nBlocksW = div_rtn<int64_t>(outputWidth + 2 * padW - dW * (kW - 1) - 1, sW) + 1;
3638

3739
if (inputLength != (nBlocksH * nBlocksW)) {
3840
THError("Given output_size=(%d, %d), kernel_size=(%d, %d), "

aten/src/THCUNN/generic/Im2Col.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/Im2Col.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(Im2Col_shapeCheck)(
68
THCState *state,
79
THCTensor *input,
@@ -29,8 +31,8 @@ static inline void THNN_(Im2Col_shapeCheck)(
2931
int64_t nInputPlane = THCTensor_(size)(state, input, dim_batch + 1);
3032
int64_t inputHeight = THCTensor_(size)(state, input, dim_batch + 2);
3133
int64_t inputWidth = THCTensor_(size)(state, input, dim_batch + 3);
32-
int64_t outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1;
33-
int64_t outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1;
34+
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2 * padH - (dH * (kH - 1) + 1), sH) + 1;
35+
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2 * padW - (dW * (kW - 1) + 1), sW) + 1;
3436

3537
if (outputHeight < 1 || outputWidth < 1) {
3638
THError("Given input with spatial size (%d, %d), kernel_size=(%d, %d), "

aten/src/THCUNN/generic/SpatialConvolutionMM.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/SpatialConvolutionMM.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
68
THCState *state,
79
THCTensor *input, THCTensor *gradOutput,
@@ -49,8 +51,8 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
4951
exactInputHeight, exactInputWidth, kH, kW);
5052
}
5153

52-
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
53-
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
54+
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
55+
int64_t outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;
5456

5557
if (outputWidth < 1 || outputHeight < 1) {
5658
THError("Given input size per channel: (%ld x %ld). "

aten/src/THCUNN/generic/SpatialDilatedConvolution.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/SpatialDilatedConvolution.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
68
THCState *state,
79
THCTensor *input, THCTensor *gradOutput,
@@ -44,8 +46,8 @@ static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
4446
int64_t inputHeight = input->size(dimh);
4547
int64_t inputWidth = input->size(dimw);
4648

47-
int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
48-
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
49+
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
50+
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;
4951

5052
if (outputWidth < 1 || outputHeight < 1) {
5153
THError("Given input size per channel: (%ld x %ld). "

aten/src/THCUNN/generic/VolumetricConvolution.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/VolumetricConvolution.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(VolumetricConvolution_shapeCheck)
68
(THCState *state,
79
THCTensor *input,
@@ -83,9 +85,9 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
8385
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
8486
}
8587

86-
int64_t outputWidth = (exactInputDepth - kH) / dH + 1;
87-
int64_t outputHeight = (exactInputHeight - kT) / dT + 1;
88-
int64_t outputDepth = (exactInputWidth - kW) / dW + 1;
88+
int64_t outputWidth = div_rtn<int64_t>(exactInputDepth - kH, dH) + 1;
89+
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kT, dT) + 1;
90+
int64_t outputDepth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;
8991

9092
if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
9193
{

aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "generic/VolumetricDilatedConvolution.cu"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
68
THCState *state,
79
THCTensor *input,
@@ -53,9 +55,9 @@ static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
5355
int64_t inputDepth = input->size(dimd);
5456
int64_t inputHeight = input->size(dimh);
5557
int64_t inputWidth = input->size(dimw);
56-
int64_t outputDepth = (inputDepth + 2*padT - (dilationT * (kT - 1) + 1)) / dT + 1;
57-
int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
58-
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
58+
int64_t outputDepth = div_rtn<int64_t>(inputDepth + 2*padT - (dilationT * (kT - 1) + 1), dT) + 1;
59+
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
60+
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;
5961

6062
if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) {
6163
THError("Given input size per channel: (%ld x %ld x %ld). "

aten/src/THNN/generic/Col2Im.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define TH_GENERIC_FILE "generic/Col2Im.c"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
// Note [im2col/col2im output padding]
68
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
79
// Our implementations of im2col and col2im take both the input height/width as
@@ -138,8 +140,8 @@ static inline void THNN_(Col2Im_shapeCheck)(
138140
}
139141

140142
int64_t inputLength = input->size(batch_dim + 2);
141-
int64_t nBlocksH = 1 + (outputHeight + 2 * padH - dH * (kH - 1) - 1) / sH;
142-
int64_t nBlocksW = 1 + ( outputWidth + 2 * padW - dW * (kW - 1) - 1) / sW;
143+
int64_t nBlocksH = div_rtn<int64_t>(outputHeight + 2 * padH - dH * (kH - 1) - 1, sH) + 1;
144+
int64_t nBlocksW = div_rtn<int64_t>(outputWidth + 2 * padW - dW * (kW - 1) - 1, sW) + 1;
143145

144146
if (inputLength != (nBlocksH * nBlocksW)) {
145147
THError("Given output_size=(%d, %d), kernel_size=(%d, %d), "

aten/src/THNN/generic/Im2Col.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define TH_GENERIC_FILE "generic/Im2Col.c"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(Im2Col_shapeCheck)(
68
THNNState *state,
79
THTensor *input,
@@ -27,8 +29,8 @@ static inline void THNN_(Im2Col_shapeCheck)(
2729
int64_t nInputPlane = THTensor_(size)(input, dim_batch + 1);
2830
int64_t inputHeight = THTensor_(size)(input, dim_batch + 2);
2931
int64_t inputWidth = THTensor_(size)(input, dim_batch + 3);
30-
int64_t outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1;
31-
int64_t outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1;
32+
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2 * padH - (dH * (kH - 1) + 1), sH) + 1;
33+
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2 * padW - (dW * (kW - 1) + 1), sW) + 1;
3234
int64_t nOutputPlane = nInputPlane * kW * kH;
3335
int64_t outputLength = outputHeight * outputWidth;
3436

aten/src/THNN/generic/SpatialConvolutionMM.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define TH_GENERIC_FILE "generic/SpatialConvolutionMM.c"
33
#else
44

5+
#include <ATen/div_rtn.h>
6+
57
static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
68
THTensor *input, THTensor *gradOutput,
79
THTensor *weight, THTensor *bias,
@@ -48,8 +50,8 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
4850
exactInputHeight, exactInputWidth, kH, kW);
4951
}
5052

51-
int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
52-
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
53+
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
54+
int64_t outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;
5355

5456
if (outputWidth < 1 || outputHeight < 1) {
5557
THError("Given input size per channel: (%ld x %ld). "

0 commit comments

Comments
 (0)