Skip to content

Commit bb097e2

Browse files
albanDsoumith
authored andcommitted
[pytorch] Fix signed random_ (#6463)
* Fix cpu signed random * fix gpu signed tensor * add test for signed random_ * cleaner tests * fix lint
1 parent f41044f commit bb097e2

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

aten/src/TH/generic/THTensorRandom.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ void THTensor_(clampedRandom)(THTensor *self, THGenerator *_generator, int64_t m
3333
uint64_t range = max - min;
3434
#if defined(TH_REAL_IS_LONG) || defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
3535
if (range >= 1ULL << 32) {
36-
TH_TENSOR_APPLY(real, self, *self_data = (real)((THRandom_random64(_generator) % range) + min);)
36+
TH_TENSOR_APPLY(real, self, *self_data = static_cast<real>(static_cast<int64_t>((THRandom_random64(_generator) % range) + min));)
3737
return;
3838
}
3939
#endif
40-
TH_TENSOR_APPLY(real, self, *self_data = (real)((THRandom_random(_generator) % range) + min);)
40+
TH_TENSOR_APPLY(real, self, *self_data = static_cast<real>(static_cast<int64_t>((THRandom_random(_generator) % range) + min));)
4141
}
4242

4343
void THTensor_(cappedRandom)(THTensor *self, THGenerator *_generator, int64_t max) {

aten/src/THC/generic/THCTensorRandom.cu

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,13 +459,17 @@ GENERATE_KERNEL1(generate_geometric, real, double p, float, curand_uniform, (Sca
459459
#endif
460460
461461
#if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
462-
#define CURAND64(STATE) (((uint64_t)curand(&state[blockIdx.x])) << 32) | (uint64_t)curand(&state[blockIdx.x])
463-
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand, (real)(x % range + base))
464-
GENERATE_KERNEL2(generate_random_64, real, int64_t base, uint64_t range, uint64_t, CURAND64, (real)(x % range + base))
462+
#define CURAND64(STATE) (((uint64_t)curand(STATE)) << 32) | (uint64_t)curand(STATE)
463+
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand, \
464+
static_cast<real>(static_cast<int32_t>((x % range) + base)))
465+
GENERATE_KERNEL2(generate_random_64, real, int64_t base, uint64_t range, uint64_t, CURAND64, \
466+
static_cast<real>(static_cast<int64_t>((x % range) + base)))
465467
#elif defined(THC_REAL_IS_HALF)
466-
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand, (ScalarConvert<uint32_t, real>::to(x % range + base)))
468+
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand,
469+
(ScalarConvert<int32_t, real>::to(static_cast<int32_t>(x % range + base))))
467470
#else
468-
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand, (real)(x % range + base))
471+
GENERATE_KERNEL2(generate_random, real, int32_t base, uint32_t range, uint32_t, curand,
472+
static_cast<real>(static_cast<int32_t>(x % range + base)))
469473
#endif
470474
471475
THC_API void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)

test/test_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,9 @@ def test_nvtx(self):
15631563
torch.cuda.nvtx.mark("bar")
15641564
torch.cuda.nvtx.range_pop()
15651565

1566+
def test_random_neg_values(self):
1567+
TestTorch._test_random_neg_values(self, use_cuda=True)
1568+
15661569

15671570
def load_ignore_file():
15681571
from os.path import join, dirname

test/test_torch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,21 @@ def test_random(self):
23562356
self.assertEqual(t.min(), 0)
23572357
self.assertEqual(t.max(), ub - 1)
23582358

2359+
@staticmethod
2360+
def _test_random_neg_values(self, use_cuda=False):
2361+
signed_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor',
2362+
'torch.IntTensor', 'torch.ShortTensor']
2363+
for tname in signed_types:
2364+
res = torch.rand(SIZE, SIZE).type(tname)
2365+
if use_cuda:
2366+
res = res.cuda()
2367+
res.random_(-10, -1)
2368+
self.assertLessEqual(res.max().item(), 9)
2369+
self.assertGreaterEqual(res.min().item(), -10)
2370+
2371+
def test_random_neg_values(self):
2372+
self._test_random_neg_values(self)
2373+
23592374
def assertIsOrdered(self, order, x, mxx, ixx, task):
23602375
SIZE = 4
23612376
if order == 'descending':

0 commit comments

Comments
 (0)