-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add torch.count_nonzero
#39992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torch.count_nonzero
#39992
Conversation
💊 CI failures summary and remediationsAs of commit 61a9937 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 1 fixed upstream failure:These were probably caused by upstream breakages that were already fixed.
Please rebase on the
|
|
Please review |
mruberry
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @kshitij12345, thanks for the PR! I added some comments.
Summary + followup checklist:
- Suggestion to simplify test in test_torch.py
- Minor doc improvements
- Method variants
- Add a "not_implemented" entry to tools/autograd/derivatives.yaml
- Update docs/source/name_inference.rst
- Update docs/source/tensors.rst and docs/source/torch.rst
- Update torch/_torch_docs.py
Overall this is really good, it just needs a few updates.
|
@mruberry I believe have addressed the comments. Please review. |
mruberry
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few more test changes (sorry about that, I made a note to make self.compare_with_numpy easier to use!) and I suggest you remove the name variant to keep this PR more focused, otherwise you'll need to add tests for it, and that could be a separate PR if you wanted.
Also I made a note on the name_inference.rst doc.
* scale the randomly generated values to have bigger range for intergral. * pass the input data as tolist(), this will force compare_with_numpy, to obey the passed device and dtype.
|
#40064 clarifies the compare_with_numpy behavior, thank you for revealing an issue with it. You will have to rebase after it goes in, unfortunately, but the code change it requires is very small. We really want to avoid those tolist() calls. |
* add extremal case for floating dtypes.
|
PTAL:) |
|
Would be great if you can have another look at this. I am planning to use the Thank You. |
mruberry
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work. This looks really good, @kshitij12345!
Should the test also consider complex dtypes?
|
Have added complex dtypes to the test as well. |
Thanks kshitij12345! We just need to wait for the tests again since there was a merge conflict with master. |
|
Another merge conflict, unfortunately. We'll have to wait for the tests again. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@mruberry Gentle ping :) |
Thanks for the ping. I'm fighting with our internal landing system, unfortunately. |
|
Oh. Sure. Thank You! :) |
Reference #38349
TODO: