-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add per-element unique op for CPU #5503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
0904426
Initial commit for unique op
theweiho a8d5ffa
Working unique with test
theweiho 55a8b5a
Make inverse indices shape conform to input
theweiho a197a7b
flake8 whitespace removal
theweiho 3c26309
address review comment nits
theweiho b48b427
Expose fn and add docs. Explicitly declare no gradients
theweiho 00244e8
Trial generic dispatch implementation
theweiho bd49d0a
Add tests for generics
theweiho a3dc7b2
flake8 whitespace
theweiho a016ea2
Add basic CUDA error throwing and templateize set
theweiho 294ed5d
Explicit contiguous and AT_DISPATCH_ALL_TYPES return
theweiho 6534579
Remove extraneous numpy conversion
theweiho 69c61f4
Refactor out .data calls
theweiho 7ae734b
Refactored to variable return length API with wrapper fn as opposed t…
theweiho b07dcc3
Remove A
theweiho 6c1b125
Don't use hidden torch._unique() in test
theweiho eeebf16
Fix documentations
theweiho File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| // Returns unique elements of input tensor. | ||
|
|
||
| #include "ATen/ATen.h" | ||
| #include "ATen/Dispatch.h" | ||
|
|
||
| #include <set> | ||
| #include <tuple> | ||
| #include <unordered_map> | ||
| #include <unordered_set> | ||
|
|
||
| namespace at { | ||
| namespace native{ | ||
|
|
||
| namespace { | ||
|
|
||
| template <template <class...> class set_type, typename scalar_t> | ||
| std::tuple<Tensor, Tensor> _unique_cpu_template( | ||
| const Tensor& self, | ||
| const bool return_inverse) { | ||
| const Tensor& input = self.contiguous(); | ||
| const scalar_t* input_data = input.data<scalar_t>(); | ||
| set_type<scalar_t> set(input_data, input_data + input.numel()); | ||
| Tensor output = input.type().tensor({static_cast<int64_t>(set.size())}); | ||
| scalar_t* output_data = output.data<scalar_t>(); | ||
| std::copy(set.begin(), set.end(), output_data); | ||
|
|
||
| Tensor inverse_indices = self.type().toScalarType(kLong).tensor({0}); | ||
| if (return_inverse) { | ||
| inverse_indices.resize_(input.sizes()); | ||
| int64_t* inverse_indices_data = inverse_indices.data<int64_t>(); | ||
| std::unordered_map<scalar_t, int64_t> inverse_map; | ||
| inverse_map.reserve(output.numel()); | ||
| for (int i = 0; i < output.numel(); ++i) { | ||
| inverse_map[output_data[i]] = i; | ||
| } | ||
| for (int i = 0; i < input.numel(); ++i) { | ||
| inverse_indices_data[i] = inverse_map[input_data[i]]; | ||
| } | ||
| } | ||
| return std::make_tuple(output, inverse_indices); | ||
| } | ||
| } // namespace | ||
|
|
||
| std::tuple<Tensor, Tensor> | ||
| _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) { | ||
| if (sorted) { | ||
| return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] { | ||
| return _unique_cpu_template<std::set, scalar_t>(self, return_inverse); | ||
| }); | ||
| } else { | ||
| return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] { | ||
| return _unique_cpu_template<std::unordered_set, scalar_t>( | ||
| self, return_inverse); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| } // namespace native | ||
| } // namespace at |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| #include "ATen/ATen.h" | ||
|
|
||
| #include <tuple> | ||
|
|
||
| namespace at { | ||
| namespace native{ | ||
|
|
||
| std::tuple<Tensor, Tensor> | ||
| _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { | ||
| throw std::runtime_error( | ||
| "unique is currently CPU-only, and lacks CUDA support. " | ||
| "Pull requests welcome!"); | ||
| } | ||
|
|
||
| } // namespace native | ||
| } // namespace at |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.