Skip to content

Commit a3bd7b2

Browse files
theweihoapaszke
authored andcommitted
Optimize unique sorting by using std::vector+sort instead of std::set (#5913)
1 parent 537e0e0 commit a3bd7b2

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

aten/src/ATen/native/Unique.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,24 @@ namespace native{
1313

1414
namespace {
1515

16-
template <template <class...> class set_type, typename scalar_t>
16+
template <typename scalar_t>
1717
std::tuple<Tensor, Tensor> _unique_cpu_template(
1818
const Tensor& self,
19+
const bool sorted,
1920
const bool return_inverse) {
2021
const Tensor& input = self.contiguous();
2122
const scalar_t* input_data = input.data<scalar_t>();
22-
set_type<scalar_t> set(input_data, input_data + input.numel());
23+
std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
2324
Tensor output = input.type().tensor({static_cast<int64_t>(set.size())});
2425
scalar_t* output_data = output.data<scalar_t>();
25-
std::copy(set.begin(), set.end(), output_data);
26+
27+
if (sorted) {
28+
std::vector<scalar_t> vec(set.begin(), set.end());
29+
std::sort(vec.begin(), vec.end());
30+
std::copy(vec.begin(), vec.end(), output_data);
31+
} else {
32+
std::copy(set.begin(), set.end(), output_data);
33+
}
2634

2735
Tensor inverse_indices = self.type().toScalarType(kLong).tensor({0});
2836
if (return_inverse) {
@@ -43,16 +51,9 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
4351

4452
std::tuple<Tensor, Tensor>
4553
_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
46-
if (sorted) {
47-
return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
48-
return _unique_cpu_template<std::set, scalar_t>(self, return_inverse);
49-
});
50-
} else {
51-
return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
52-
return _unique_cpu_template<std::unordered_set, scalar_t>(
53-
self, return_inverse);
54-
});
55-
}
54+
return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
55+
return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
56+
});
5657
}
5758

5859
} // namespace native

0 commit comments

Comments
 (0)