Commit edfcbfb
Implement randperm for CUDA (pytorch#7606)
* Implement randperm for CUDA
* Use Thrust to implement randperm
* clean up
* Fix test
* Offload small input scenario to CPU
* Fixed test
* Try to fix Windows error
* Fix Windows error and clean up
* Use fork_rng context manager
* Move test_randperm_cuda to test_cuda
* Add half tensor support
* Fix cuda::type error
* Fix CPU offloading
* Fix issues
* No need to check range for n == 0 case
1 parent 9af3a80 commit edfcbfb

4 files changed: +103 −5 lines

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 2 additions & 5 deletions

@@ -263,18 +263,15 @@ THGenerator* get_generator(at::Generator* gen) {
 
 Tensor randperm(const Type& dtype, int64_t n, Generator* generator) {
   Tensor result = dtype.tensor(n);
-  return at::native::randperm_out(result, n, generator);
+  return at::randperm_out(result, n, generator);
 }
 
-Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) {
+Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
   if (n < 0) {
     std::ostringstream oss;
     oss << "n must be non-negative, got " << n;
     throw std::runtime_error(oss.str());
   }
-  if (result.type().backend() != at::kCPU) {
-    throw std::runtime_error("randperm is only implemented for CPU");
-  }
 
   result.resize_({n});
   auto gen = get_generator(generator);

aten/src/ATen/native/cuda/TensorFactories.cu

Lines changed: 63 additions & 0 deletions

@@ -1,4 +1,14 @@
+#include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
+#include "ATen/cuda/CUDATypeConversion.cuh"
+
+#include <THC/THCGeneral.h>
+#include <THC/THCThrustAllocator.cuh>
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+#include <thrust/execution_policy.h>
+#include <thrust/sequence.h>
+
 #include <algorithm>
 #include <sstream>
 

@@ -26,4 +36,57 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
   return result;
 }
 
+Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
+  if (n < 0) {
+    std::ostringstream oss;
+    oss << "n must be non-negative, got " << n;
+    throw std::runtime_error(oss.str());
+  }
+
+  if (n > 0) {
+    AT_DISPATCH_ALL_TYPES_AND_HALF(
+      result.type(), "randperm_out_cuda", [&] {
+        AT_CHECK(Scalar(n).to<scalar_t>(),
+          "n is too large for result tensor type: '", result.type().toString(), "'");
+      }
+    );
+  }
+
+  result.resize_({n});
+
+  if (result.type().scalarType() == at::ScalarType::Half) {
+    auto result_float = CUDA(kFloat).tensor({n});
+    result.copy_(randperm_out_cuda(result_float, n, generator));
+  } else {
+    if (n < 30000) {  // For small inputs, we offload it to CPU instead.
+      auto result_cpu = result.type().toBackend(kCPU).tensor({n});
+      randperm_out(result_cpu, n, generator);
+      result.copy_(result_cpu);
+    } else {
+      // Generate random values for the keys array
+      AT_DISPATCH_ALL_TYPES(
+        result.type(), "randperm_out_cuda", [&] {
+          using cuda_scalar_t = cuda::into_type<scalar_t>;
+
+          auto keys = result.type().tensor(result.sizes()).random_(generator);
+
+          auto result_data = thrust::device_ptr<cuda_scalar_t>(result.data<cuda_scalar_t>());
+          auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>());
+
+          auto state = globalContext().getTHCState();
+          THCThrustAllocator thrustAlloc(state);
+          auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));
+
+          thrust::sequence(policy, result_data, result_data + n);
+
+          // Use the sorted order of keys to rearrange the result array
+          thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
+        }
+      );
+    }
+  }
+
+  return result;
+}
+
 }} // namespace at::native
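The core of the CUDA path is a key–value sort: result is filled with 0…n−1 by thrust::sequence, each slot is paired with an independent random key, and thrust::sort_by_key shuffles the values into the keys' sorted order. Below is a minimal self-contained sketch of that trick using only Thrust, outside ATen; the index-hashing random_key functor is a stand-in for the commit's random_() call and is not part of this change.

// randperm_sketch.cu -- minimal sketch of the key-sort permutation trick.
// Compile with: nvcc randperm_sketch.cu
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/random.h>
#include <thrust/iterator/counting_iterator.h>
#include <cstdio>

// One pseudo-random key per index; a stand-in for ATen's random_().
struct random_key {
  unsigned int seed;
  __host__ __device__ unsigned int operator()(unsigned int i) const {
    thrust::default_random_engine rng(seed);
    rng.discard(i);  // decorrelate draws across indices
    return rng();
  }
};

int main() {
  const int n = 16;
  thrust::device_vector<int> perm(n);
  thrust::device_vector<unsigned int> keys(n);

  thrust::sequence(perm.begin(), perm.end());        // perm = 0, 1, ..., n-1
  thrust::transform(thrust::counting_iterator<unsigned int>(0),
                    thrust::counting_iterator<unsigned int>(n),
                    keys.begin(), random_key{42u});  // keys = random draws

  // Sorting the keys carries the values along: perm is now a random permutation.
  thrust::sort_by_key(keys.begin(), keys.end(), perm.begin());

  for (int i = 0; i < n; ++i) std::printf("%d ", static_cast<int>(perm[i]));
  std::printf("\n");
}

Two details of the real implementation follow from this design: duplicate keys are possible in a bounded integer type, in which case the sort breaks ties arbitrarily but still produces a valid permutation, and the n < 30000 branch offloads small inputs to the CPU kernel, since (per the code comment) the CPU path is faster at that scale.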

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions

@@ -800,6 +800,9 @@
 
 - func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
   variants: function
+  dispatch:
+    CPU: randperm_out_cpu
+    CUDA: randperm_out_cuda
 
 - func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
   variants: function
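The dispatch: entry is what makes at::randperm_out backend-aware: the code generator reads native_functions.yaml and routes calls on CPU tensors to randperm_out_cpu and calls on CUDA tensors to randperm_out_cuda. A self-contained sketch of the idea, with illustrative names rather than ATen's actual generated code:

// Sketch of what the `dispatch:` entry accomplishes; names and the explicit
// switch are illustrative only, not ATen's generated dispatcher.
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class Backend { CPU, CUDA };
struct Tensor { Backend backend; };

Tensor& randperm_out_cpu(Tensor& t, int64_t n)  { std::printf("CPU kernel, n=%lld\n",  (long long)n); return t; }
Tensor& randperm_out_cuda(Tensor& t, int64_t n) { std::printf("CUDA kernel, n=%lld\n", (long long)n); return t; }

// The public entry point routes on the tensor's backend, which is the role
// the generated dispatch table plays for at::randperm_out.
Tensor& randperm_out(Tensor& result, int64_t n) {
  switch (result.backend) {
    case Backend::CPU:  return randperm_out_cpu(result, n);
    case Backend::CUDA: return randperm_out_cuda(result, n);
  }
  throw std::runtime_error("randperm_out: unsupported backend");
}

int main() {
  Tensor cpu{Backend::CPU}, gpu{Backend::CUDA};
  randperm_out(cpu, 10);   // -> CPU kernel
  randperm_out(gpu, 10);   // -> CUDA kernel
}

This is also why the TensorFactories.cpp hunk switches the call from at::native::randperm_out to at::randperm_out: the at:: entry point goes through dispatch, so randperm on a CUDA tensor now reaches randperm_out_cuda instead of hitting the old CPU-only runtime check.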

test/test_cuda.py

Lines changed: 35 additions & 0 deletions

@@ -1653,6 +1653,41 @@ def test_nvtx(self):
         torch.cuda.nvtx.mark("bar")
         torch.cuda.nvtx.range_pop()
 
+    def test_randperm_cuda(self):
+        cuda = torch.device('cuda:0')
+
+        # For small inputs, randperm is offloaded to CPU instead
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100, device=cuda)
+        res2 = torch.cuda.LongTensor()
+        torch.randperm(100, out=res2, device=cuda)
+        self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100000, device=cuda)
+        res2 = torch.cuda.LongTensor()
+        torch.randperm(100000, out=res2, device=cuda)
+        self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100, dtype=torch.half, device=cuda)
+        res2 = torch.cuda.HalfTensor()
+        torch.randperm(100, out=res2, device=cuda)
+        self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
+        res2 = torch.cuda.HalfTensor()
+        torch.randperm(50000, out=res2, device=cuda)
+        self.assertEqual(res1, res2, 0)
+
+        # randperm of 0 elements is an empty tensor
+        res1 = torch.randperm(0, device=cuda)
+        res2 = torch.cuda.LongTensor(5)
+        torch.randperm(0, out=res2, device=cuda)
+        self.assertEqual(res1.numel(), 0)
+        self.assertEqual(res2.numel(), 0)
+
     def test_random_neg_values(self):
         TestTorch._test_random_neg_values(self, use_cuda=True)
