-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Doc note update for complex autograd #45270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
[ghstack-poisoned]
|
cc. @boeddeker |
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks mostly good for me.
Just small comments.
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 851ca95 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 21 times. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
docs/source/notes/autograd.rst
Outdated
| def F(z): | ||
| x, y = real(z), imag(z) | ||
| return u(x, y) + v(x, y) * 1j | ||
| s = u(x, y) + v(x, y) * 1j | ||
| return s | ||
| where :math:`1j` is a unit imaginary number. | ||
| where :math:`1j` is a unit imaginary number and :math:`u` and :math:`v` are :math:`ℝ^{2} → ℝ` functions. | ||
|
|
||
| We define the :math:`JVP` for function :math:`F` at :math:`(x, y)` applied to a tangent | ||
| vector :math:`c+dj \in C` as: | ||
| We define the Vector-Jacobian Product `VJP` of :math:`F` at :math:`x + yj` for grad output vector :math:`grad\_out` as: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a complex numbers n00b, I think this documentation is a great step to helping developers understand how to write (or update) gradient formulas to support complex numbers.
I think something else that would be helpful is if there was an Appendix somewhere that stepped through the derivation of some operator (like torch.mul, or torch.dot). When I started out understanding how complex number gradients worked, I tried to derive some formulas (e.g., for torch.mul) and compare them against the existing gradient formulas.
Perhaps that would be better for developer documentation though -- I'm not sure if the purpose of notes/autograd is for users to look at and understand what's going on, or for developers to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zou3519 thanks for the feedback! I added the screenshots in the description as per your suggestion. I agree that it would be useful to step through the derivation for one of the operators. However, I am not sure if this doc note is the right place to do that. We can possibly include a gradient derivation for one of the operators here: https://pytorch.org/docs/stable/complex_numbers.html#autograd along with an example of how the gradients can be used in an optimization problem.
[ghstack-poisoned]
Codecov Report
@@ Coverage Diff @@
## gh/anjali411/58/base #45270 +/- ##
=====================================================
Coverage 68.25% 68.25%
=====================================================
Files 410 410
Lines 53232 53232
=====================================================
+ Hits 36335 36336 +1
+ Misses 16897 16896 -1
Continue to review full report at Codecov.
|
docs/source/notes/autograd.rst
Outdated
| mode automatic differentiation. Notice that d and :math:`1j` are negated in the formula above. Please look at | ||
| the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_ | ||
| to get explanation for the negative signs in the formula. | ||
| The above formula can also be verified to be correct for cross-domain functions functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: One functions too much in functions functions.
docs/source/notes/autograd.rst
Outdated
| x, y = real(z), imag(z) | ||
| return u(x, y) + v(x, y) * 1j | ||
| s = u(x, y) + v(x, y) * 1j | ||
| return s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just an idea. You start here directly with the assumption, that the chain rule is known.
Maybe it would be better for a beginner, when you start with the core assumption:
e.g.: Interpreted the complex number as two numbers and calculate the derivative:
grad_x = dG / dx where G \in \mathcal{R}
grad_y = dG / dy where G \in \mathcal{R}
Since complex numbers can contain 2 real numbers, we can define:
grad = grad_x + 1j * grad_y
or
grad = grad_x - 1j * grad_y
And this is known as Wirtinger Calculus.
(Note: Dropping the 1/2 in eq. (3.2) of https://arxiv.org/pdf/1701.00392.pdf is usually what NNs do. At the time of writing the paper we weren't careful, because this equation is never implemented)
docs/source/notes/autograd.rst
Outdated
| return s | ||
| The :math:`JVP` and :math:`VJP` for a :math:`f1: ℝ^2 → ℂ` are defined as: | ||
| at :math:`x + yj` can be simplified to: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is always difficult to describe. But you cannot evaluate it at x + yj, because z is real valued.
Or should this operation be C -> R -> C -> R i.e. G(F(real_to_complex(real(z))))) (Not sure, if this matches your notation. Probably is G in your case the complete function) ?
|
For the example of the complex derivation, it would be nice to extend the example with PyTorch code. |
<img width="1672" alt="Screen Shot 2020-09-24 at 1 09 05 PM" src="https://user-images.githubusercontent.com/20081078/94177334-5f2b9600-fe67-11ea-97cf-b2949153a117.png"> <img width="1673" alt="Screen Shot 2020-09-24 at 1 09 16 PM" src="https://user-images.githubusercontent.com/20081078/94177346-62bf1d00-fe67-11ea-94cc-d429b55a73b1.png"> [ghstack-poisoned]
@boeddeker - @ezyang and I rewrote the doc note and incorporated more background knowledge for complex derivatives in the updated version. you can check out how the html rendering in the PR description. |
<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> [ghstack-poisoned]
<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> [ghstack-poisoned]
docs/source/notes/autograd.rst
Outdated
| convention for autograd for Complex Numbers. | ||
| - PyTorch Autograd returns gradients that are directly usable for optimization and gradient descent (no conjugation needed), | ||
| so existing optimizers work out of the box. We don’t make any attempt to support holomorphic functions specially; | ||
| in particular, derivatives are only well defined for real-valued loss functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realize this is my own damn fault because I wrote this text in the first place, but taken literally this text is wrong (I'm not sure what I even meant when I wrote this)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm I rephrased it. does it look any better?
<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> [ghstack-poisoned]
<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> [ghstack-poisoned]
<img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> [ghstack-poisoned]
|
@anjali411 merged this pull request in 8925661. |
Summary: Pull Request resolved: pytorch#45270 <img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D24203257 Pulled By: anjali411 fbshipit-source-id: cd637dade5fb40cecf5d9f4bd03d508d36e26fcd
Summary: Pull Request resolved: #45270 <img width="1679" alt="Screen Shot 2020-10-07 at 1 45 59 PM" src="https://user-images.githubusercontent.com/20081078/95368324-fa7b2d00-08a3-11eb-9066-2e659a4085a2.png"> <img width="1673" alt="Screen Shot 2020-10-07 at 1 46 10 PM" src="https://user-images.githubusercontent.com/20081078/95368332-fbac5a00-08a3-11eb-9be5-77ce6deb8967.png"> <img width="1667" alt="Screen Shot 2020-10-07 at 1 46 30 PM" src="https://user-images.githubusercontent.com/20081078/95368337-fe0eb400-08a3-11eb-80a2-5ad23feeeb83.png"> <img width="1679" alt="Screen Shot 2020-10-07 at 1 46 48 PM" src="https://user-images.githubusercontent.com/20081078/95368345-00710e00-08a4-11eb-96d9-e2d544554a4b.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 03 PM" src="https://user-images.githubusercontent.com/20081078/95368350-023ad180-08a4-11eb-89b3-f079480741f4.png"> <img width="1680" alt="Screen Shot 2020-10-07 at 1 47 12 PM" src="https://user-images.githubusercontent.com/20081078/95368364-0535c200-08a4-11eb-82fc-9435a046e4ca.png"> Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D24203257 Pulled By: anjali411 fbshipit-source-id: cd637dade5fb40cecf5d9f4bd03d508d36e26fcd Co-authored-by: anjali411 <chourdiaanjali123@gmail.com>
Stack from ghstack:
Differential Revision: D24203257