@@ -13,16 +13,24 @@ namespace native {
 
 namespace {
 
-template <template <class...> class set_type, typename scalar_t>
+template <typename scalar_t>
 std::tuple<Tensor, Tensor> _unique_cpu_template(
     const Tensor& self,
+    const bool sorted,
     const bool return_inverse) {
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
-  set_type<scalar_t> set(input_data, input_data + input.numel());
+  std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
   Tensor output = input.type().tensor({static_cast<int64_t>(set.size())});
   scalar_t* output_data = output.data<scalar_t>();
-  std::copy(set.begin(), set.end(), output_data);
+
+  if (sorted) {
+    std::vector<scalar_t> vec(set.begin(), set.end());
+    std::sort(vec.begin(), vec.end());
+    std::copy(vec.begin(), vec.end(), output_data);
+  } else {
+    std::copy(set.begin(), set.end(), output_data);
+  }
 
   Tensor inverse_indices = self.type().toScalarType(kLong).tensor({0});
   if (return_inverse) {
@@ -43,16 +51,9 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
 
 std::tuple<Tensor, Tensor>
 _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
-  if (sorted) {
-    return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
-      return _unique_cpu_template<std::set, scalar_t>(self, return_inverse);
-    });
-  } else {
-    return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
-      return _unique_cpu_template<std::unordered_set, scalar_t>(
-          self, return_inverse);
-    });
-  }
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
+    return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
+  });
 }
 
 } // namespace native
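
For reference, the logic the new _unique_cpu_template body implements can be sketched as a standalone helper: deduplicate through a hash set in one pass, then sort only the unique values when sorted is requested. This is a minimal illustration of the pattern, not the ATen code itself; the helper name unique_values, the std::vector input, and the main driver are assumptions made for the sake of a runnable example, standing in for tensor storage and the AT_DISPATCH_ALL_TYPES dispatch.

// Minimal standalone sketch of the pattern in the diff above (illustrative
// only): hash-based deduplication first, then an optional sort of the
// unique values. unique_values is a hypothetical name, not an ATen API.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

template <typename scalar_t>
std::vector<scalar_t> unique_values(const std::vector<scalar_t>& input,
                                    bool sorted) {
  // One pass over the input collects the distinct values.
  std::unordered_set<scalar_t> set(input.begin(), input.end());
  std::vector<scalar_t> output(set.begin(), set.end());
  if (sorted) {
    // Sorting only the deduplicated values is cheaper than sorting the
    // whole input whenever duplicates are common.
    std::sort(output.begin(), output.end());
  }
  return output;
}

int main() {
  const std::vector<int64_t> data = {3, 1, 3, 2, 1};
  for (int64_t v : unique_values(data, /*sorted=*/true)) {
    std::cout << v << ' ';  // prints: 1 2 3
  }
  std::cout << '\n';
  return 0;
}

Folding the sorted flag into the template body this way is also what lets _unique_cpu collapse its two AT_DISPATCH_ALL_TYPES branches into one, since the set type no longer needs to be chosen at compile time.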