-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add matrix_rank #10338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add matrix_rank #10338
Conversation
- Similar functionality as NumPy - Added doc string - Added tests
| return V.mm(S_pseudoinv.diag().mm(U.t())); | ||
| } | ||
|
|
||
| static double _get_epsilon(const ScalarType& sc_type) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| "of floating types"); | ||
|
|
||
| Tensor S = _matrix_rank_helper(self, symmetric); | ||
| return (S > tol).toType(kLong).sum().toCLong(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@ssnl Is this good to go? |
| "of floating types"); | ||
|
|
||
| Tensor S = _matrix_rank_helper(self, symmetric); | ||
| return (S > tol).toType(kLong).sum(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return V.mm(S_pseudoinv.diag().mm(U.t())); | ||
| } | ||
|
|
||
| static double _get_epsilon(const ScalarType& sc_type) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| "of floating types"); | ||
|
|
||
| Tensor S = _matrix_rank_helper(self, symmetric); | ||
| double tol = S.max().toCDouble() * std::max<double>(self.size(0), self.size(1)) * |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
…nto matrix-rank
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| return V.mm(S_pseudoinv.diag().mm(U.t())); | ||
| } | ||
|
|
||
| Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| "of floating types"); | ||
|
|
||
| Tensor S = _matrix_rank_helper(self, symmetric); | ||
| Tensor tol = S.max() * _get_epsilon(self.type()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/_torch_docs.py
Outdated
| :attr:`tol` is the threshold below which the singular values (or the eigenvalues | ||
| when :attr:`symmetric` is ``True``) are considered to be 0. If :attr:`tol` is not | ||
| specified, :attr:`tol` is set to `S.max() * max(S.size()) * eps` where `S` is the |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/_torch_docs.py
Outdated
| >>> a = torch.eye(10) | ||
| >>> torch.matrix_rank(a) | ||
| 10 |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| static inline Tensor _get_epsilon(const Type& type) { | ||
| switch (type.scalarType()) { | ||
| case at::ScalarType::Half: | ||
| return type.tensor({}).fill_(std::numeric_limits<at::Half>::epsilon()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
goldsborough
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be a future fix.
|
|
||
| Tensor S = _matrix_rank_helper(self, symmetric); | ||
| Tensor tol = _get_epsilon(self.type()) * std::max(self.size(0), self.size(1)); | ||
| tol.mul_(S.max()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
ssnl
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
torch/_torch_docs.py
Outdated
| :attr:`tol` is the threshold below which the singular values (or the eigenvalues | ||
| when :attr:`symmetric` is ``True``) are considered to be 0. If :attr:`tol` is not | ||
| specified, :attr:`tol` is set to ``S.max() * max(S.size()) * eps`` where `S` is the | ||
| singular values (or the eigenvalues when :attr:`symmetric` is ``True``), and `eps` |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@ssnl is this good to go? |
|
@ssnl ping :) |
|
I tried to get another person to look before merging but wasn't able to. Will get this going tomorrow. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong') | ||
|
|
||
| @staticmethod | ||
| def _test_matrix_rank(self, conv_fn): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@ssnl is this good to go? |
|
It’s failing internal test. I’ve no idea why yet, but I’ll look this
afternoon.
…On Fri, Aug 17, 2018 at 12:40 Vishwak Srinivasan ***@***.***> wrote:
@ssnl <https://github.com/SsnL> is this good to go?
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#10338 (comment)>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AFaWZb8d-qn1fW7RjAUpfNUKWLmz97vsks5uRvIHgaJpZM4VzIvZ>
.
|
|
@vishwakftw There is some really weird things going on. I tried but can't quite repro or figure out why. I will talk with people knowing more about those tests on Monday. Sorry about it. |
|
Oh no worries. Please take your time. After all, it’s a weekend :) |
|
@vishwakftw the 3rd arg of |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: - Similar functionality as NumPy - Added doc string - Added tests Differential Revision: D9240850 Pulled By: SsnL fbshipit-source-id: 1d04cfadb076e99e03bdf699bc41b8fac06831bf
Summary: - Similar functionality as NumPy - Added doc string - Added tests Differential Revision: D9240850 Pulled By: SsnL fbshipit-source-id: 1d04cfadb076e99e03bdf699bc41b8fac06831bf
cc: @ssnl
Closes #10292