Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions aten/src/ATen/div_rtn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

// Integer division rounding to -Infinity
template<typename T>
static inline T div_rtn(T x, T y) {
int q = x/y;
int r = x%y;
if ((r!=0) && ((r<0) != (y<0))) --q;
return q;
}

6 changes: 4 additions & 2 deletions aten/src/THCUNN/generic/Col2Im.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/Col2Im.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(Col2Im_shapeCheck)(
THCState *state,
THCTensor *input,
Expand Down Expand Up @@ -31,8 +33,8 @@ static inline void THNN_(Col2Im_shapeCheck)(
}

int64_t inputLength = input->size(batch_dim + 2);
int64_t nBlocksH = 1 + (outputHeight + 2 * padH - dH * (kH - 1) - 1) / sH;
int64_t nBlocksW = 1 + ( outputWidth + 2 * padW - dW * (kW - 1) - 1) / sW;
int64_t nBlocksH = div_rtn<int64_t>(outputHeight + 2 * padH - dH * (kH - 1) - 1, sH) + 1;
int64_t nBlocksW = div_rtn<int64_t>(outputWidth + 2 * padW - dW * (kW - 1) - 1, sW) + 1;

if (inputLength != (nBlocksH * nBlocksW)) {
THError("Given output_size=(%d, %d), kernel_size=(%d, %d), "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THCUNN/generic/Im2Col.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/Im2Col.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(Im2Col_shapeCheck)(
THCState *state,
THCTensor *input,
Expand Down Expand Up @@ -29,8 +31,8 @@ static inline void THNN_(Im2Col_shapeCheck)(
int64_t nInputPlane = THCTensor_(size)(state, input, dim_batch + 1);
int64_t inputHeight = THCTensor_(size)(state, input, dim_batch + 2);
int64_t inputWidth = THCTensor_(size)(state, input, dim_batch + 3);
int64_t outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1;
int64_t outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2 * padH - (dH * (kH - 1) + 1), sH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2 * padW - (dW * (kW - 1) + 1), sW) + 1;

if (outputHeight < 1 || outputWidth < 1) {
THError("Given input with spatial size (%d, %d), kernel_size=(%d, %d), "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THCUNN/generic/SpatialConvolutionMM.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/SpatialConvolutionMM.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
THCState *state,
THCTensor *input, THCTensor *gradOutput,
Expand Down Expand Up @@ -49,8 +51,8 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
exactInputHeight, exactInputWidth, kH, kW);
}

int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;

if (outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld). "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THCUNN/generic/SpatialDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/SpatialDilatedConvolution.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
THCState *state,
THCTensor *input, THCTensor *gradOutput,
Expand Down Expand Up @@ -44,8 +46,8 @@ static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
int64_t inputHeight = input->size(dimh);
int64_t inputWidth = input->size(dimw);

int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;

if (outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld). "
Expand Down
8 changes: 5 additions & 3 deletions aten/src/THCUNN/generic/VolumetricConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/VolumetricConvolution.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(VolumetricConvolution_shapeCheck)
(THCState *state,
THCTensor *input,
Expand Down Expand Up @@ -83,9 +85,9 @@ static inline void THNN_(VolumetricConvolution_shapeCheck)
exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW);
}

int64_t outputWidth = (exactInputDepth - kH) / dH + 1;
int64_t outputHeight = (exactInputHeight - kT) / dT + 1;
int64_t outputDepth = (exactInputWidth - kW) / dW + 1;
int64_t outputWidth = div_rtn<int64_t>(exactInputDepth - kH, dH) + 1;
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kT, dT) + 1;
int64_t outputDepth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;

if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1)
{
Expand Down
8 changes: 5 additions & 3 deletions aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "generic/VolumetricDilatedConvolution.cu"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
THCState *state,
THCTensor *input,
Expand Down Expand Up @@ -53,9 +55,9 @@ static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
int64_t inputDepth = input->size(dimd);
int64_t inputHeight = input->size(dimh);
int64_t inputWidth = input->size(dimw);
int64_t outputDepth = (inputDepth + 2*padT - (dilationT * (kT - 1) + 1)) / dT + 1;
int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
int64_t outputDepth = div_rtn<int64_t>(inputDepth + 2*padT - (dilationT * (kT - 1) + 1), dT) + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;

if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld x %ld). "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THNN/generic/Col2Im.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/Col2Im.c"
#else

#include <ATen/div_rtn.h>

// Note [im2col/col2im output padding]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Our implementations of im2col and col2im take both the input height/width as
Expand Down Expand Up @@ -138,8 +140,8 @@ static inline void THNN_(Col2Im_shapeCheck)(
}

int64_t inputLength = input->size(batch_dim + 2);
int64_t nBlocksH = 1 + (outputHeight + 2 * padH - dH * (kH - 1) - 1) / sH;
int64_t nBlocksW = 1 + ( outputWidth + 2 * padW - dW * (kW - 1) - 1) / sW;
int64_t nBlocksH = div_rtn<int64_t>(outputHeight + 2 * padH - dH * (kH - 1) - 1, sH) + 1;
int64_t nBlocksW = div_rtn<int64_t>(outputWidth + 2 * padW - dW * (kW - 1) - 1, sW) + 1;

if (inputLength != (nBlocksH * nBlocksW)) {
THError("Given output_size=(%d, %d), kernel_size=(%d, %d), "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THNN/generic/Im2Col.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/Im2Col.c"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(Im2Col_shapeCheck)(
THNNState *state,
THTensor *input,
Expand All @@ -27,8 +29,8 @@ static inline void THNN_(Im2Col_shapeCheck)(
int64_t nInputPlane = THTensor_(size)(input, dim_batch + 1);
int64_t inputHeight = THTensor_(size)(input, dim_batch + 2);
int64_t inputWidth = THTensor_(size)(input, dim_batch + 3);
int64_t outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1;
int64_t outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2 * padH - (dH * (kH - 1) + 1), sH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2 * padW - (dW * (kW - 1) + 1), sW) + 1;
int64_t nOutputPlane = nInputPlane * kW * kH;
int64_t outputLength = outputHeight * outputWidth;

Expand Down
6 changes: 4 additions & 2 deletions aten/src/THNN/generic/SpatialConvolutionMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/SpatialConvolutionMM.c"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
THTensor *input, THTensor *gradOutput,
THTensor *weight, THTensor *bias,
Expand Down Expand Up @@ -48,8 +50,8 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)(
exactInputHeight, exactInputWidth, kH, kW);
}

int64_t outputHeight = (exactInputHeight - kH) / dH + 1;
int64_t outputWidth = (exactInputWidth - kW) / dW + 1;
int64_t outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;

if (outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld). "
Expand Down
6 changes: 4 additions & 2 deletions aten/src/THNN/generic/SpatialDilatedConvolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/SpatialDilatedConvolution.c"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
THTensor *input, THTensor *gradOutput,
THTensor *weight, THTensor *bias,
Expand Down Expand Up @@ -43,8 +45,8 @@ static inline void THNN_(SpatialDilatedConvolution_shapeCheck)(
int64_t inputHeight = input->size(dimh);
int64_t inputWidth = input->size(dimw);

int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;

if (outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld). "
Expand Down
8 changes: 5 additions & 3 deletions aten/src/THNN/generic/VolumetricConvolutionMM.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/VolumetricConvolutionMM.c"
#else

#include <ATen/div_rtn.h>

#define CONV3D_OMP_THRESHOLD 20

static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
Expand Down Expand Up @@ -76,9 +78,9 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
exactInputDepth, exactInputHeight, exactInputWidth, kT, kH, kW);
}

outputDepth = (exactInputDepth - kT) / dT + 1;
outputHeight = (exactInputHeight - kH) / dH + 1;
outputWidth = (exactInputWidth - kW) / dW + 1;
outputDepth = div_rtn<int64_t>(exactInputDepth - kT, dT) + 1;
outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;


if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) {
Expand Down
8 changes: 5 additions & 3 deletions aten/src/THNN/generic/VolumetricDilatedConvolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/VolumetricDilatedConvolution.c"
#else

#include <ATen/div_rtn.h>

static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
THTensor *input, THTensor *gradOutput,
THTensor *weight, THTensor *bias,
Expand Down Expand Up @@ -47,9 +49,9 @@ static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)(
int64_t inputDepth = input->size(dimd);
int64_t inputHeight = input->size(dimh);
int64_t inputWidth = input->size(dimw);
int64_t outputDepth = (inputDepth + 2*padT - (dilationT * (kT - 1) + 1)) / dT + 1;
int64_t outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
int64_t outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
int64_t outputDepth = div_rtn<int64_t>(inputDepth + 2*padT - (dilationT * (kT - 1) + 1), dT) + 1;
int64_t outputHeight = div_rtn<int64_t>(inputHeight + 2*padH - (dilationH * (kH - 1) + 1), dH) + 1;
int64_t outputWidth = div_rtn<int64_t>(inputWidth + 2*padW - (dilationW * (kW - 1) + 1), dW) + 1;

if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) {
THError("Given input size per channel: (%ld x %ld x %ld). "
Expand Down
10 changes: 10 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,16 @@ def test_no_grad(self):
self.assertFalse(output2.requires_grad)
self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

def test_invalid_conv2d(self):
module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2)
input = torch.empty(1, 1, 4, 4)
self.assertRaises(RuntimeError, lambda: module(input))

def test_invalid_conv3d(self):
module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2)
input = torch.empty(1, 1, 4, 4, 4)
self.assertRaises(RuntimeError, lambda: module(input))

def _test_dropout(self, cls, input):
p = 0.2
input.fill_(1 - p)
Expand Down