Skip to content

Commit 08b1324

Browse files
vedanujsoumith
authored andcommitted
Fix integer overflow in remainder operator (#5906)
* Fix integer overflow in remainder * Fix remainder operator in CUDA * Add tests for remainder integer overflow * Add has_different_sign static function
1 parent 06e86a6 commit 08b1324

File tree

5 files changed

+30
-5
lines changed

5 files changed

+30
-5
lines changed

aten/src/TH/generic/THTensorMath.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,10 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, real value)
10521052
}
10531053
}
10541054

1055+
static inline bool has_different_sign(real a, real b) {
1056+
return (a < 0) != (b < 0);
1057+
}
1058+
10551059
void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
10561060
{
10571061
THTensor_(resizeAs)(r_, t);
@@ -1070,7 +1074,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
10701074
#else
10711075
// There is no NAN for integers
10721076
rp[i] = tp[i] % value;
1073-
if (rp[i] * value < 0)
1077+
if (has_different_sign(rp[i], value))
10741078
rp[i] += value;
10751079
#endif
10761080
}
@@ -1085,7 +1089,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
10851089
#else
10861090
// There is no NAN for integers
10871091
TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = *t_data % value;
1088-
if (*r__data * value < 0) *r__data += value;);
1092+
if (has_different_sign(*r__data, value)) *r__data += value;);
10891093
#endif
10901094
}
10911095
#else
@@ -1098,7 +1102,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
10981102
#else
10991103
// There is no NAN for integers
11001104
TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data % value;
1101-
if (*r__data * value < 0) *r__data += value;);
1105+
if (has_different_sign(*r__data, value)) *r__data += value;);
11021106
#endif
11031107
}
11041108
}

aten/src/THC/THCNumerics.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ static inline __host__ __device__ scalar_t powi(scalar_t a, scalar_t b) {
2828
return result;
2929
}
3030

31+
template <typename scalar_t>
32+
static inline __host__ __device__ bool has_different_sign(scalar_t a, scalar_t b) {
33+
return (a < 0) != (b < 0);
34+
}
35+
3136
template <>
3237
struct THCNumerics<uint8_t> {
3338
static inline __host__ __device__ uint8_t min() { return 0; }

aten/src/THC/THCTensorMathPairwise.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,14 @@ struct TensorRemainderOp {
245245
TensorRemainderOp(T v) : val(v) {}
246246
__device__ __forceinline__ void operator()(T* out, T* in) {
247247
*out = *in % val;
248-
if ((*out * val) < 0){
248+
if (has_different_sign<T>(*out, val)){
249249
*out += val;
250250
}
251251
}
252252

253253
__device__ __forceinline__ void operator()(T* v) {
254254
*v = *v % val;
255-
if ((*v * val) < 0){
255+
if (has_different_sign<T>(*v, val)){
256256
*v += val;
257257
}
258258
}

test/test_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,9 @@ def test_min_max_inits(self):
14341434
def test_int_pow(self):
14351435
TestTorch._test_int_pow(self, lambda x: x.cuda())
14361436

1437+
def test_remainder_overflow(self):
1438+
TestTorch._test_remainder_overflow(self, dtype=torch.cuda.int64)
1439+
14371440
def test_var(self):
14381441
cpu_tensor = torch.randn(2, 3, 3)
14391442
gpu_tensor = cpu_tensor.cuda()

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,19 @@ def test_remainder(self):
780780
long_res1 = long_m1.clone()
781781
long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
782782

783+
@staticmethod
784+
def _test_remainder_overflow(self, dtype=torch.int64):
785+
# Check Integer Overflows
786+
x = torch.tensor(23500, dtype=dtype)
787+
q = 392486996410368
788+
self.assertEqual(x % q, x)
789+
self.assertEqual(-x % q, q - x)
790+
self.assertEqual(x % -q, x - q)
791+
self.assertEqual(-x % -q, -x)
792+
793+
def test_remainder_overflow(self):
794+
self._test_remainder_overflow(self, dtype=torch.int64)
795+
783796
def test_mm(self):
784797
# helper function
785798
def matrixmultiply(mat1, mat2):

0 commit comments

Comments
 (0)