7 changes: 2 additions & 5 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -263,18 +263,15 @@ THGenerator* get_generator(at::Generator* gen) {

 Tensor randperm(const Type& dtype, int64_t n, Generator* generator) {
   Tensor result = dtype.tensor(n);
-  return at::native::randperm_out(result, n, generator);
+  return at::randperm_out(result, n, generator);
 }

-Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) {
+Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
   if (n < 0) {
     std::ostringstream oss;
     oss << "n must be non-negative, got " << n;
     throw std::runtime_error(oss.str());
   }
-  if (result.type().backend() != at::kCPU) {
-    throw std::runtime_error("randperm is only implemented for CPU");
-  }

   result.resize_({n});
   auto gen = get_generator(generator);
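Since the functional form now calls at::randperm_out (which goes through the dispatcher) instead of at::native::randperm_out (the CPU symbol directly), it automatically reaches whichever backend kernel the dispatch table below selects. From Python, the two entry points look like this; a minimal usage sketch, assuming the default int64 result type:

import torch

buf = torch.empty(10, dtype=torch.long)
torch.randperm(10, out=buf)    # out= variant: fills an existing tensor in place
perm = torch.randperm(10)      # functional variant: allocates, then forwards to the out= kernel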
63 changes: 63 additions & 0 deletions aten/src/ATen/native/cuda/TensorFactories.cu
@@ -1,4 +1,14 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/cuda/CUDATypeConversion.cuh"

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>

#include <algorithm>
#include <sstream>

@@ -26,4 +36,57 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
  return result;
}

Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
  if (n < 0) {
    std::ostringstream oss;
    oss << "n must be non-negative, got " << n;
    throw std::runtime_error(oss.str());
  }

  if (n > 0) {
    AT_DISPATCH_ALL_TYPES_AND_HALF(
      result.type(), "randperm_out_cuda", [&] {
        AT_CHECK(Scalar(n).to<scalar_t>(),
          "n is too large for result tensor type: '", result.type().toString(), "'");
      }
    );
  }

  result.resize_({n});

  if (result.type().scalarType() == at::ScalarType::Half) {
    auto result_float = CUDA(kFloat).tensor({n});
    result.copy_(randperm_out_cuda(result_float, n, generator));
  } else {
    if (n < 30000) {  // For small inputs, we offload it to CPU instead.
      auto result_cpu = result.type().toBackend(kCPU).tensor({n});
      randperm_out(result_cpu, n, generator);
      result.copy_(result_cpu);
    } else {
      // Generate random values for the keys array
      AT_DISPATCH_ALL_TYPES(
        result.type(), "randperm_out_cuda", [&] {
          using cuda_scalar_t = cuda::into_type<scalar_t>;

          auto keys = result.type().tensor(result.sizes()).random_(generator);

          auto result_data = thrust::device_ptr<cuda_scalar_t>(result.data<cuda_scalar_t>());
          auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>());

          auto state = globalContext().getTHCState();
          THCThrustAllocator thrustAlloc(state);
          auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));

          thrust::sequence(policy, result_data, result_data + n);

          // Use the sorted order of keys to rearrange the result array
          thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
        }
      );
    }
  }

  return result;
}

}} // namespace at::native
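For large n, the kernel draws one random key per element, and thrust::sort_by_key drags the identity sequence [0, 1, ..., n-1] into the keys' sorted order, yielding a random permutation; small n is shipped to the CPU implementation, and Half results are produced by generating a float permutation and copying it back. The key-sort idea itself, sketched in Python with torch ops (illustrative only, not the code path above; it assumes the keys are distinct, since colliding keys would leave ties):

import torch

def randperm_via_key_sort(n, device='cuda'):
    # One random 64-bit key per slot; sorting by key permutes the
    # identity sequence [0..n-1] into the keys' (random) order.
    keys = torch.randint(0, 2 ** 62, (n,), device=device)
    return torch.sort(keys)[1]  # the sort indices are the permutation

# Example: perm = randperm_via_key_sort(100000)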
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -798,6 +798,9 @@

 - func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
   variants: function
+  dispatch:
+    CPU: randperm_out_cpu
+    CUDA: randperm_out_cuda

- func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
variants: function
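The dispatch entry is what ties the two implementations to a single public op: the same torch.randperm call routes to randperm_out_cpu or randperm_out_cuda depending on where the result tensor lives. A quick illustration, assuming a CUDA-enabled build:

import torch

cpu_perm = torch.randperm(10)                  # routed to randperm_out_cpu
gpu_perm = torch.randperm(10, device='cuda')   # routed to randperm_out_cuda
print(cpu_perm.device, gpu_perm.device)        # cpu cuda:0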
35 changes: 35 additions & 0 deletions test/test_cuda.py
@@ -1636,6 +1636,41 @@ def test_nvtx(self):
        torch.cuda.nvtx.mark("bar")
        torch.cuda.nvtx.range_pop()

    def test_randperm_cuda(self):
        cuda = torch.device('cuda:0')

        # For small inputs, randperm is offloaded to CPU instead
        with torch.random.fork_rng(devices=[0]):
            res1 = torch.randperm(100, device=cuda)
        res2 = torch.cuda.LongTensor()
        torch.randperm(100, out=res2, device=cuda)
        self.assertEqual(res1, res2, 0)

        with torch.random.fork_rng(devices=[0]):
            res1 = torch.randperm(100000, device=cuda)
        res2 = torch.cuda.LongTensor()
        torch.randperm(100000, out=res2, device=cuda)
        self.assertEqual(res1, res2, 0)

        with torch.random.fork_rng(devices=[0]):
            res1 = torch.randperm(100, dtype=torch.half, device=cuda)
        res2 = torch.cuda.HalfTensor()
        torch.randperm(100, out=res2, device=cuda)
        self.assertEqual(res1, res2, 0)

        with torch.random.fork_rng(devices=[0]):
            res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
        res2 = torch.cuda.HalfTensor()
        torch.randperm(50000, out=res2, device=cuda)
        self.assertEqual(res1, res2, 0)

        # randperm of 0 elements is an empty tensor
        res1 = torch.randperm(0, device=cuda)
        res2 = torch.cuda.LongTensor(5)
        torch.randperm(0, out=res2, device=cuda)
        self.assertEqual(res1.numel(), 0)
        self.assertEqual(res2.numel(), 0)

    def test_random_neg_values(self):
        TestTorch._test_random_neg_values(self, use_cuda=True)

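The fork_rng pattern in these tests is what makes each res1/res2 comparison meaningful: the context manager snapshots the CPU and listed CUDA RNG states and restores them on exit, so the second randperm call replays exactly the random draws of the first. A minimal sketch of the same pattern, assuming a CUDA build:

import torch

with torch.random.fork_rng(devices=[0]):
    a = torch.randperm(8, device='cuda')   # consumes RNG state inside the fork
b = torch.randperm(8, device='cuda')       # state was restored, so this replays the draw
assert a.equal(b)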