|
1 | 1 | #include "ATen/ATen.h" |
2 | 2 | #include "ATen/NativeFunctions.h" |
| 3 | +#include "ATen/cuda/CUDATypeConversion.cuh" |
3 | 4 |
|
4 | 5 | #include <THC/THCGeneral.h> |
5 | 6 | #include <THC/THCThrustAllocator.cuh> |
@@ -42,27 +43,44 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) { |
42 | 43 | throw std::runtime_error(oss.str()); |
43 | 44 | } |
44 | 45 |
|
45 | | - result.resize_({n}); |
| 46 | + auto result_tmp = result; |
| 47 | + if (result.type().scalarType() == at::ScalarType::Half) { |
| 48 | + // Make sure n is within range of Half |
| 49 | + assert(Scalar(n).toHalf()); |
| 50 | + result_tmp = CUDA(kFloat).tensor(); |
| 51 | + } |
| 52 | + result_tmp.resize_({n}); |
46 | 53 |
|
47 | 54 | if (n < 30000) { // For small inputs, we offload it to CPU instead. |
48 | | - auto result_cpu = result.type().toBackend(kCPU).tensor({n}); |
| 55 | + auto result_cpu = result_tmp.type().toBackend(kCPU).tensor({n}); |
49 | 56 | randperm_out(result_cpu, n, generator); |
50 | | - result = result.type().copy(result_cpu); |
| 57 | + result_tmp = result_tmp.type().copy(result_cpu); |
51 | 58 | } else { |
52 | 59 | // Generate random values for the keys array |
53 | | - auto keys = result.type().tensor(result.sizes()).random_(generator); |
| 60 | + AT_DISPATCH_ALL_TYPES( |
| 61 | + result_tmp.type(), "randperm_out_cuda", [&] { |
| 62 | + using cuda_scalar_t = cuda::type<scalar_t>; |
| 63 | + |
| 64 | + auto keys = result_tmp.type().tensor(result_tmp.sizes()).random_(generator); |
54 | 65 |
|
55 | | - auto result_data = thrust::device_ptr<int64_t>(result.data<int64_t>()); |
56 | | - auto keys_data = thrust::device_ptr<int64_t>(keys.data<int64_t>()); |
| 66 | + auto result_data = thrust::device_ptr<cuda_scalar_t>(result_tmp.data<cuda_scalar_t>()); |
| 67 | + auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>()); |
57 | 68 |
|
58 | | - auto state = globalContext().getTHCState(); |
59 | | - THCThrustAllocator thrustAlloc(state); |
60 | | - auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)); |
| 69 | + auto state = globalContext().getTHCState(); |
| 70 | + THCThrustAllocator thrustAlloc(state); |
| 71 | + auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)); |
61 | 72 |
|
62 | | - thrust::sequence(policy, result_data, result_data + n); |
| 73 | + thrust::sequence(policy, result_data, result_data + n); |
| 74 | + |
| 75 | + // Use the sorted order of keys to rearrange the result array |
| 76 | + thrust::sort_by_key(policy, keys_data, keys_data + n, result_data); |
| 77 | + } |
| 78 | + ); |
| 79 | + } |
63 | 80 |
|
64 | | - // Use the sorted order of keys to rearrange the result array |
65 | | - thrust::sort_by_key(policy, keys_data, keys_data + n, result_data); |
| 81 | + if (result.type().scalarType() == at::ScalarType::Half) { |
| 82 | + result.resize_({n}); |
| 83 | + result.copy_(result_tmp); |
66 | 84 | } |
67 | 85 |
|
68 | 86 | return result; |
|
0 commit comments