Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions aten/src/ATen/native/cuda/Unique.cu
Original file line number Diff line number Diff line change
@@ -1,15 +1,89 @@
#include "ATen/ATen.h"

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>

#include <tuple>
#include <thrust/unique.h>
#include <thrust/sort.h>

namespace at {
namespace native{

#ifndef __HIP_PLATFORM_HCC__

namespace {
// For every input element, records the position of the equal element in
// `output_data`.
//
// Launch layout: 1-D grid of 1-D blocks of any size; the loop below strides by
// the total number of launched threads, so every input element is visited no
// matter how many threads were launched. No shared memory is used.
//
// Parameters:
//   input_data           - `num_inp` values to look up.
//   output_data          - `num_out` candidate values. Must be sorted in
//                          ascending order without duplicates (the caller in
//                          this file produces it with thrust::sort followed by
//                          thrust::unique).
//   inverse_indices_data - `num_inp` slots; slot i receives j such that
//                          output_data[j] == input_data[i]. Slots whose value
//                          has no equal element in `output_data` are left
//                          untouched.
//
// NOTE(review): the binary search relies on `output_data` being ordered under
// operator<. Floating-point NaN entries are not ordered by operator<, so
// lookups in an array that contains NaNs may miss - confirm the NaN policy
// expected by callers.
template <typename scalar_t>
__global__ void inverse_indices_kernel(
    const scalar_t* input_data,
    const scalar_t* output_data,
    int64_t* inverse_indices_data,
    int64_t num_inp,
    int64_t num_out) {
  // Widen before multiplying so the products are formed in 64 bits rather
  // than in the 32-bit unsigned type of the built-in index variables.
  const int64_t first = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  for (int64_t i = first; i < num_inp; i += stride) {
    const scalar_t value = input_data[i];

    // Lower-bound binary search: after the loop `lo` is the first position
    // whose element is not less than `value` (or num_out if none exists).
    int64_t lo = 0;
    int64_t hi = num_out;
    while (lo < hi) {
      const int64_t mid = lo + (hi - lo) / 2;
      if (output_data[mid] < value) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }

    // Only write on an exact match, mirroring the equality test of the
    // pairwise comparison this search replaces.
    if (lo < num_out && output_data[lo] == value) {
      inverse_indices_data[i] = lo;
    }
  }
}


// Computes the unique values of `self` on the current CUDA stream.
//
// The input is flattened, copied, sorted with thrust::sort and de-duplicated
// with thrust::unique, so the returned values are always in ascending order.
//
// Parameters:
//   self           - tensor whose elements are of type `scalar_t`.
//   return_inverse - when true, also compute for every input element the
//                    position of its value inside the unique output.
//
// Returns a tuple (output, inverse_indices):
//   output          - 1-D tensor holding the `num_out` unique values.
//   inverse_indices - kLong tensor; shaped like the input when
//                     `return_inverse` is true, otherwise left with size {0}.
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_cuda_template(
    const Tensor& self,
    const bool return_inverse) {

  cudaStream_t stream = globalContext().getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  const Tensor& input = self.contiguous();
  int64_t num_inp = input.numel();
  const scalar_t* input_data = input.data<scalar_t>();

  // sort & unique, performed on a flattened copy so the input is untouched
  Tensor output = input.clone();
  output = output.view(-1);
  scalar_t* output_data = output.data<scalar_t>();
  thrust::sort(policy, output_data, output_data + num_inp);
  scalar_t* output_end = thrust::unique(policy, output_data, output_data + num_inp);
  int64_t num_out = output_end - output_data;
  output.resize_(num_out);

  Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));

  if (return_inverse) {
    inverse_indices.resize_(input.sizes());
    // An empty input has nothing to map, and the grid size computed below
    // would be 0, which is not a valid launch configuration; skip the launch.
    if (num_inp > 0) {
      int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
      const int block = 512;
      // Cap the grid at 2048 blocks; the kernel loops with a stride of the
      // total thread count, so a capped grid still covers all of the work.
      const int grid = static_cast<int>(std::min<int64_t>(
          (num_inp * num_out + block - 1) / block, static_cast<int64_t>(2048)));
      inverse_indices_kernel<<<grid, block, 0, stream>>>(
        input_data, output_data, inverse_indices_data, num_inp, num_out);
    }
  }

  // Surfaces launch-configuration errors from the kernel launch above.
  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor, Tensor>(output, inverse_indices);

}
} // namespace

#endif

// CUDA entry point for `unique`.
//
// Dispatches on the scalar type of `self` to _unique_cuda_template and returns
// its (unique values, inverse indices) tuple. On HIP builds the operation is
// not available and an error is raised instead.
//
// Parameters:
//   self           - input tensor.
//   sorted         - accepted for interface compatibility but not consulted:
//                    the CUDA implementation always returns sorted values.
//   return_inverse - forwarded to the implementation; requests the inverse
//                    index tensor.
std::tuple<Tensor, Tensor>
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
#ifndef __HIP_PLATFORM_HCC__
  return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
    // The current CUDA implementation of unique always sorts, due to the
    // lack of a hashtable implementation in thrust.
    return _unique_cuda_template<scalar_t>(self, return_inverse);
  });
#else
  AT_ERROR("unique_cuda: HIP not supported");
#endif
}

} // namespace native
Expand Down
12 changes: 1 addition & 11 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7580,7 +7580,7 @@ def test_set_flush_denormal(self):
self.assertEqual(double_tensor[2], 0.0, prec=0.0) # tiny_double to zero
torch.set_flush_denormal(False)

def test_unique_cpu(self):
def test_unique(self):
x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
Expand Down Expand Up @@ -7630,16 +7630,6 @@ def test_unique_cpu(self):
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)

@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_unique_cuda(self):
# unique currently does not support CUDA.
self.assertRaises(
RuntimeError, lambda: torch.cuda.LongTensor([0, 1]).unique())
self.assertRaises(
RuntimeError,
lambda: torch.unique(torch.cuda.FloatTensor([0., 1.])),
)

@staticmethod
def _test_bincount(self, device):
# negative input throws
Expand Down