Skip to content

Commit a656bd0

Browse files
author
Will Feng
committed
Add half tensor support
1 parent 90ca62d commit a656bd0

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

aten/src/ATen/native/cuda/TensorFactories.cu

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "ATen/ATen.h"
22
#include "ATen/NativeFunctions.h"
3+
#include "ATen/cuda/CUDATypeConversion.cuh"
34

45
#include <THC/THCGeneral.h>
56
#include <THC/THCThrustAllocator.cuh>
@@ -42,27 +43,44 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
4243
throw std::runtime_error(oss.str());
4344
}
4445

45-
result.resize_({n});
46+
auto result_tmp = result;
47+
if (result.type().scalarType() == at::ScalarType::Half) {
48+
// Make sure n is within range of Half
49+
assert(Scalar(n).toHalf());
50+
result_tmp = CUDA(kFloat).tensor();
51+
}
52+
result_tmp.resize_({n});
4653

4754
if (n < 30000) { // For small inputs, we offload it to CPU instead.
48-
auto result_cpu = result.type().toBackend(kCPU).tensor({n});
55+
auto result_cpu = result_tmp.type().toBackend(kCPU).tensor({n});
4956
randperm_out(result_cpu, n, generator);
50-
result = result.type().copy(result_cpu);
57+
result_tmp = result_tmp.type().copy(result_cpu);
5158
} else {
5259
// Generate random values for the keys array
53-
auto keys = result.type().tensor(result.sizes()).random_(generator);
60+
AT_DISPATCH_ALL_TYPES(
61+
result_tmp.type(), "randperm_out_cuda", [&] {
62+
using cuda_scalar_t = cuda::type<scalar_t>;
63+
64+
auto keys = result_tmp.type().tensor(result_tmp.sizes()).random_(generator);
5465

55-
auto result_data = thrust::device_ptr<int64_t>(result.data<int64_t>());
56-
auto keys_data = thrust::device_ptr<int64_t>(keys.data<int64_t>());
66+
auto result_data = thrust::device_ptr<cuda_scalar_t>(result_tmp.data<cuda_scalar_t>());
67+
auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>());
5768

58-
auto state = globalContext().getTHCState();
59-
THCThrustAllocator thrustAlloc(state);
60-
auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));
69+
auto state = globalContext().getTHCState();
70+
THCThrustAllocator thrustAlloc(state);
71+
auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));
6172

62-
thrust::sequence(policy, result_data, result_data + n);
73+
thrust::sequence(policy, result_data, result_data + n);
74+
75+
// Use the sorted order of keys to rearrange the result array
76+
thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
77+
}
78+
);
79+
}
6380

64-
// Use the sorted order of keys to rearrange the result array
65-
thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
81+
if (result.type().scalarType() == at::ScalarType::Half) {
82+
result.resize_({n});
83+
result.copy_(result_tmp);
6684
}
6785

6886
return result;

test/test_cuda.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,18 @@ def test_randperm_cuda(self):
16521652
torch.randperm(100000, out=res2, device=cuda)
16531653
self.assertEqual(res1, res2, 0)
16541654

1655+
with torch.random.fork_rng(devices=[0]):
1656+
res1 = torch.randperm(100, dtype=torch.half, device=cuda)
1657+
res2 = torch.cuda.HalfTensor()
1658+
torch.randperm(100, out=res2, device=cuda)
1659+
self.assertEqual(res1, res2, 0)
1660+
1661+
with torch.random.fork_rng(devices=[0]):
1662+
res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
1663+
res2 = torch.cuda.HalfTensor()
1664+
torch.randperm(50000, out=res2, device=cuda)
1665+
self.assertEqual(res1, res2, 0)
1666+
16551667
# randperm of 0 elements is an empty tensor
16561668
res1 = torch.randperm(0, device=cuda)
16571669
res2 = torch.cuda.LongTensor(5)

0 commit comments

Comments
 (0)