Skip to content

Commit 0f7c371

Browse files
xuhdev authored and facebook-github-bot committed
Support Half type in randperm.
Summary: Pull Request resolved: #22102 Test Plan: Imported from OSS Differential Revision: D16153586 Pulled By: li-roy fbshipit-source-id: d58e3dbc5da893005f4eaf521a28b0d752274eff
1 parent 9c4c9c3 commit 0f7c371

File tree

5 files changed

+41
-5
lines changed

5 files changed

+41
-5
lines changed

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,12 @@ Tensor& randperm_out(Tensor& result, int64_t n) {
525525

526526
Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
527527
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
528+
check_supported_max_int_with_precision(n, result);
528529
result.resize_({n});
529530
auto gen = get_generator_or_default<CPUGenerator>(generator, detail::getDefaultCPUGenerator());
530531
// See Note [Acquire lock when using random generators]
531532
std::lock_guard<std::mutex> lock(gen->mutex_);
532-
AT_DISPATCH_ALL_TYPES(result.scalar_type(), "randperm", [&]() -> void {
533+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "randperm", [&]() -> void {
533534
randperm_cpu<scalar_t>(result, n, gen);
534535
});
535536

aten/src/ATen/native/TensorFactories.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,25 @@ inline void check_size_nonnegative(IntArrayRef size) {
6464
TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size);
6565
}
6666
}
67+
68+
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
69+
TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(),
70+
"n is too large for result tensor type: '", tensor.type().toString(), "'");
71+
72+
// Ensure sufficient precision for floating point representation.
73+
switch (tensor.scalar_type()) {
74+
case at::ScalarType::Half:
75+
TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
76+
break;
77+
case at::ScalarType::Float:
78+
TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
79+
break;
80+
case at::ScalarType::Double: // Unlikely to happen, but doesn't hurt to check
81+
TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
82+
break;
83+
default:
84+
break;
85+
}
86+
}
6787
} // namespace native
6888
} // namespace at

aten/src/ATen/native/cuda/TensorFactories.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOpti
7878

7979
Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
8080
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
81-
TORCH_CHECK(at::scalar_tensor(n, result.options()).defined(),
82-
"n is too large for result tensor type: '", result.type().toString(), "'");
81+
check_supported_max_int_with_precision(n, result);
8382

8483
result.resize_({n});
8584

test/test_cuda.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,8 +2790,8 @@ def test_randperm_cuda(self):
27902790
self.assertEqual(res1, res2, 0)
27912791

27922792
with torch.random.fork_rng(devices=[0]):
2793-
res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
2794-
res2 = torch.cuda.HalfTensor()
2793+
res1 = torch.randperm(50000, dtype=torch.float, device=cuda)
2794+
res2 = torch.cuda.FloatTensor()
27952795
torch.randperm(50000, out=res2, device=cuda)
27962796
self.assertEqual(res1, res2, 0)
27972797

@@ -2802,6 +2802,14 @@ def test_randperm_cuda(self):
28022802
self.assertEqual(res1.numel(), 0)
28032803
self.assertEqual(res2.numel(), 0)
28042804

2805+
# Test exceptions when n is too large for a floating point type
2806+
for res, small_n, large_n in ((torch.cuda.HalfTensor(), 2**11 + 1, 2**11 + 2),
2807+
(torch.cuda.FloatTensor(), 2**24 + 1, 2**24 + 2),
2808+
(torch.cuda.DoubleTensor(), 2**25, # 2**53 + 1 is too large to run
2809+
2**53 + 2)):
2810+
torch.randperm(small_n, out=res) # No exception expected
2811+
self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res))
2812+
28052813
def test_random_neg_values(self):
28062814
_TestTorchMixin._test_random_neg_values(self, use_cuda=True)
28072815

test/test_torch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4512,6 +4512,14 @@ def test_randperm(self):
45124512
self.assertEqual(res1.numel(), 0)
45134513
self.assertEqual(res2.numel(), 0)
45144514

4515+
# Test exceptions when n is too large for a floating point type
4516+
for res, small_n, large_n in ((torch.HalfTensor(), 2**11 + 1, 2**11 + 2),
4517+
(torch.FloatTensor(), 2**24 + 1, 2**24 + 2),
4518+
(torch.DoubleTensor(), 2**25, # 2**53 + 1 is too large to run
4519+
2**53 + 2)):
4520+
torch.randperm(small_n, out=res) # No exception expected
4521+
self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res))
4522+
45154523
def test_random(self):
45164524
# This test is flaky with p<=(2/(ub-lb))^200=6e-36
45174525
t = torch.FloatTensor(200)

0 commit comments

Comments (0)