Update backward formula for torch.dot and add backward definition for torch.vdot #45074
Conversation
… torch.vdot [ghstack-poisoned]
💊 CI failures summary and remediations
As of commit bc6a696 (more details on the Dr. CI page):
🕵️ 6 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
('dot', (L,), ((L,),), '', (True,)),
('vdot', (L,), ((L,),), '', (True,)),
This should be (False,) at the end (to turn off JIT autodiff testing)
Alternatively, ('vdot', (L,), ((L,),)) does the trick.
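A quick sketch of the two suggested fixes (hypothetical, mirroring the review comments above; L here is only a placeholder for the test harness's length constant):

```python
# Hypothetical sketch of the two suggested fixes for the vdot test entry.
L = 20  # placeholder stand-in for the length constant used by the test harness

entry_disable_autodiff = ('vdot', (L,), ((L,),), '', (False,))  # explicitly turn off JIT autodiff testing
entry_drop_autodiff    = ('vdot', (L,), ((L,),))                # or drop the trailing autodiff tuple entirely
```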
cc @eellison: torch.vdot is not supported by JIT autodiff right now.
The question here is: Is it OK that autodiff doesn't support vdot? It's a new operator we added recently. Also, does the JIT still use the old autodiff pass?
It should be fine if it's not supported by JIT autodiff: having an autodiff formula is an optimization that allows fusion to occur during training.
Yes, it still uses the old autodiff pass. The way to define a backward is in the symbolic_script.cpp file. Currently there is really only autodiff coverage for pointwise ops, because those are the ops that we codegen fusion for.
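For illustration only, a hedged sketch of the symbolic_script pattern mentioned above (not something this PR adds; the eager backward for dot/vdot lives in the autograd derivative definitions instead). The pattern pairs a forward computation with a closure that produces the gradients, shown here for a real-valued dot:

```python
# Hypothetical symbolic_script-style entry, real-valued dot only.
import torch

def dot(self, tensor):
    def backward(grad_output):
        grad_self = grad_output * tensor     # gradient of sum(self * tensor) w.r.t. self
        grad_tensor = grad_output * self     # gradient of sum(self * tensor) w.r.t. tensor
        return grad_self, grad_tensor

    return torch.dot(self, tensor), backward
```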
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) [ghstack-poisoned]
Codecov Report
@@ Coverage Diff @@
## gh/anjali411/57/base #45074 +/- ##
=======================================================
Coverage ? 67.85%
=======================================================
Files ? 384
Lines ? 50020
Branches ? 0
=======================================================
Hits ? 33940
Misses ? 16080
Partials ? 0

Continue to review the full report at Codecov.
Tensor correct_dtype_gradients(ScalarType self_st, Tensor gradient_result) {
  if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
    // R -> C case: the input is real but its gradient came back complex,
    // so keep only the real part.
    return at::real(gradient_result);
  }
  return gradient_result;
}
"correct_dtype_gradients" seems like a very generic name (e.g. it could be mistaken for a function that handles float -> double dtype conversion). Also, it looks like this function will be used quite a lot in autograd formulas.
Tossing some quick ideas out there:
- "handle_r_to_c", "handle_real_to_complex"
- "maybe_real_part"
It seems to me that at::real should be a no-op on real tensors, similar to how at::conj is already a no-op on real tensors.
@zou3519 Makes sense. handle_real_to_complex should be clearer, I think.
@ezyang Yeah, I agree. The reason we disabled at::real for non-complex tensors before was that it would be weird to have at::real return a view for non-complex tensors while at::imag returned a new tensor populated with zeros (which it did before). We can certainly enable just at::real for non-complex tensors if we want, though.
Sure, but now that at::conj is a no-op for real tensors, I think we should probably be OK with making at::real do the same as well. I don't care... too much about imag, I don't think it shows up in situations like this.
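For context, a hedged eager-mode illustration of the R -> C case being discussed (this is not the PR's C++ helper): when a real input participates in a complex computation, the raw gradient flowing back is complex, and only its real part is a valid gradient for that input.

```python
# Hypothetical sketch of the R -> C gradient case the helper guards against.
import torch

x = torch.randn(3)                              # real input
w = torch.randn(3, dtype=torch.cfloat)          # complex operand
grad_out = torch.ones((), dtype=torch.cfloat)   # upstream (complex) gradient

raw_grad = grad_out * w.conj()                  # candidate gradient for x; complex dtype
grad_for_x = raw_grad.real if not x.is_complex() else raw_grad
```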
Besides Richard's comment, the rest of the PR looks reasonable.
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) [ghstack-poisoned]
zou3519 left a comment:
LGTM after the name change. The discussion on at::real being a no-op for real tensors seems orthogonal, so I am approving to unblock.
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) Differential Revision: [D23975361](https://our.internmc.facebook.com/intern/diff/D23975361) [ghstack-poisoned]
@anjali411 merged this pull request in 18876b5.
Stack from ghstack:
TODO: Add R -> C tests in #44744 (blocked on some JIT changes)
Differential Revision: D23975361
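As a rough idea of what the eager-mode complex tests could look like once the blockers are resolved, here is a hedged gradcheck sketch (not the tests tracked in the TODO above):

```python
# Hedged sketch: checking the dot/vdot backwards against numerical gradients.
# Double precision is used because gradcheck relies on finite differences.
import torch
from torch.autograd import gradcheck

for dtype in (torch.double, torch.cdouble):
    a = torch.randn(5, dtype=dtype, requires_grad=True)
    b = torch.randn(5, dtype=dtype, requires_grad=True)
    assert gradcheck(torch.dot, (a, b))
    assert gradcheck(torch.vdot, (a, b))
```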