Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions aten/src/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,10 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, real value)
}
}

static inline bool has_different_sign(real a, real b) {
return (a < 0) != (b < 0);
}

void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
{
THTensor_(resizeAs)(r_, t);
Expand All @@ -1070,7 +1074,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
#else
// There is no NAN for integers
rp[i] = tp[i] % value;
if (rp[i] * value < 0)
if (has_different_sign(rp[i], value))
rp[i] += value;
#endif
}
Expand All @@ -1085,7 +1089,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
#else
// There is no NAN for integers
TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = *t_data % value;
if (*r__data * value < 0) *r__data += value;);
if (has_different_sign(*r__data, value)) *r__data += value;);
#endif
}
#else
Expand All @@ -1098,7 +1102,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
#else
// There is no NAN for integers
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data % value;
if (*r__data * value < 0) *r__data += value;);
if (has_different_sign(*r__data, value)) *r__data += value;);
#endif
}
}
Expand Down
5 changes: 5 additions & 0 deletions aten/src/THC/THCNumerics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ static inline __host__ __device__ scalar_t powi(scalar_t a, scalar_t b) {
return result;
}

template <typename scalar_t>
static inline __host__ __device__ bool has_different_sign(scalar_t a, scalar_t b) {
return (a < 0) != (b < 0);
}

template <>
struct THCNumerics<uint8_t> {
static inline __host__ __device__ uint8_t min() { return 0; }
Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,14 @@ struct TensorRemainderOp {
TensorRemainderOp(T v) : val(v) {}
__device__ __forceinline__ void operator()(T* out, T* in) {
*out = *in % val;
if ((*out * val) < 0){
if (has_different_sign<T>(*out, val)){
*out += val;
}
}

__device__ __forceinline__ void operator()(T* v) {
*v = *v % val;
if ((*v * val) < 0){
if (has_different_sign<T>(*v, val)){
*v += val;
}
}
Expand Down
3 changes: 3 additions & 0 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,9 @@ def test_min_max_inits(self):
def test_int_pow(self):
TestTorch._test_int_pow(self, lambda x: x.cuda())

def test_remainder_overflow(self):
TestTorch._test_remainder_overflow(self, dtype=torch.cuda.int64)

def test_var(self):
cpu_tensor = torch.randn(2, 3, 3)
gpu_tensor = cpu_tensor.cuda()
Expand Down
13 changes: 13 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,19 @@ def test_remainder(self):
long_res1 = long_m1.clone()
long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))

@staticmethod
def _test_remainder_overflow(self, dtype=torch.int64):
# Check Integer Overflows
x = torch.tensor(23500, dtype=dtype)
q = 392486996410368
self.assertEqual(x % q, x)
self.assertEqual(-x % q, q - x)
self.assertEqual(x % -q, x - q)
self.assertEqual(-x % -q, -x)

def test_remainder_overflow(self):
self._test_remainder_overflow(self, dtype=torch.int64)

def test_mm(self):
# helper function
def matrixmultiply(mat1, mat2):
Expand Down