Autograd Doc for Complex Numbers #41012

@@ -210,3 +210,82 @@ No thread safety on C++ hooks

Autograd relies on the user to write thread safe C++ hooks. If you want the hook
to be correctly applied in a multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd with complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
that compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = z.real, z.imag
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is the imaginary unit.
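
For concreteness, here is one runnable version of the snippet above (an added illustration, not part of the original text), written out for the squaring function :math:`F(z) = z^2` with explicit :math:`u` and :math:`v`:

.. code::

    # Illustration only: F(z) = z**2 decomposed into real-valued u and v.
    def u(x, y):
        return x * x - y * y      # real part of (x + y*1j)**2

    def v(x, y):
        return 2 * x * y          # imaginary part of (x + y*1j)**2

    def F(z):
        x, y = z.real, z.imag
        return u(x, y) + v(x, y) * 1j

    assert F(1 + 2j) == (1 + 2j) ** 2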

Contributor: It's probably better to explicitly say that this is Python pseudocode. For example, you couldn't actually use [...]. Or even better, make this real code using the real PyTorch operations. Then you don't have to explain the pseudocode syntax.

Contributor (Author): I updated this section to just use math notation instead. It's simpler to read and avoids the confusion.

We define the :math:`JVP` for the function :math:`F` at :math:`(x, y)` applied to a tangent
vector :math:`c + dj \in ℂ` as:

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

where

.. math::
    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix}

This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the multiplication
with :math:`[1, 1j]^T` is used to identify the result as a complex number.
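
As a worked illustration (added here, not in the original text): for the holomorphic function :math:`F(z) = z^2` we have :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`, so

.. math::
    J = \begin{bmatrix} 2x & -2y \\ 2y & 2x \end{bmatrix}, \qquad
    \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}
    = 2(x + yj)(c + dj) = 2z \cdot (c + dj)

which is the classical derivative :math:`F'(z) = 2z` applied to the tangent :math:`c + dj`.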

We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c + dj \in ℂ` as:

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above. Please refer to
the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
for an explanation of the negative signs in the formula.
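
The following small numerical sketch (an added illustration; it uses NumPy rather than PyTorch) evaluates this VJP formula for :math:`F(z) = z^2` and checks that, for this holomorphic function, it reduces to multiplying the cotangent by :math:`F'(z) = 2z`:

.. code::

    import numpy as np

    # Illustrative check of the VJP formula for F(z) = z**2, i.e.
    # u(x, y) = x**2 - y**2 and v(x, y) = 2*x*y.
    x, y = 1.0, 2.0          # the point z = 1 + 2j
    c, d = 0.5, -1.5         # the cotangent c + dj

    # 2x2 real Jacobian of (u, v) with respect to (x, y).
    J = np.array([[2 * x, -2 * y],
                  [2 * y,  2 * x]])

    # VJP = [c, -d] * J * [1, -1j]^T, as defined above.
    vjp = np.array([c, -d]) @ J @ np.array([1.0, -1.0j])

    # For a holomorphic function this equals F'(z) * (c + dj) = 2z * (c + dj).
    z, cotangent = x + y * 1j, c + d * 1j
    assert np.isclose(vjp, 2 * z * cotangent)
    print(vjp)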

**What happens if I call backward() on a complex scalar?**
************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we cannot hope to represent all of them within a single complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
Cauchy-Riemann equations, which ensure that the `2x2` Jacobian has the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using `backward`, which is just a call to `vjp` with the covector `1.0`.

Collaborator: Reading more about this, it feels like the "pure" C function definition does not hold.

Contributor: What do you mean by the "pure" C function definition not holding?

Collaborator: What I call the "pure C function definition" here is what happens if you apply the derivative definition (as a limit) from real functions to complex functions: you get derivatives only for holomorphic functions.
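
To spell out the Cauchy-Riemann point above (an added detail): the equations :math:`\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y}` and :math:`\frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x}` force the Jacobian of a holomorphic function into the form

.. math::
    J = \begin{bmatrix} a & -b \\ b & a \end{bmatrix}, \qquad
    a = \frac{\partial u}{\partial x}, \; b = \frac{\partial v}{\partial x}

which acts on the pair :math:`(c, d)` exactly as multiplication by the single complex number :math:`a + bj` acts on :math:`c + dj`.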

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).
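
A minimal sketch of this interpretation (added here; it deliberately uses only real-valued autograd, representing :math:`z = x + y \cdot 1j` by the pair :math:`(x, y)`, so it does not depend on complex tensor support):

.. code::

    import torch

    # Represent z = x + y*1j by the real pair (x, y) and use real autograd only.
    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(2.0, requires_grad=True)

    u = x * x - y * y   # real part of z**2
    v = 2 * x * y       # imaginary part of z**2 (its partials are the ones discarded)

    # Per the description above, only the partial derivatives of the real
    # part u are kept, so we differentiate u alone here.
    u.backward()
    print(x.grad, y.grad)  # du/dx = 2x = 2.0, du/dy = -2y = -4.0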

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.

**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be the identity),
we use the formulas given below for cross-domain functions.

The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

The :math:`JVP` and :math:`VJP` for a function :math:`f2: ℝ^2 → ℂ` are defined as:

.. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
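
As a consistency check (an added worked example), take :math:`f1(z) = (x, y)` and :math:`f2(x, y) = x + y \cdot 1j`, so that :math:`f2 \circ f1 : ℂ → ℝ^2 → ℂ` is the identity and both Jacobians are the :math:`2 \times 2` identity matrix :math:`I`. Chaining the VJPs above, a cotangent :math:`c + dj` first passes through :math:`f2` and then through :math:`f1`:

.. math::
    \begin{bmatrix} c & -d \end{bmatrix} * I = \begin{bmatrix} c & -d \end{bmatrix}
    \quad\text{(VJP through } f2\text{)}, \qquad
    \begin{bmatrix} c & -d \end{bmatrix} * I * \begin{bmatrix} 1 \\ -1j \end{bmatrix} = c + dj
    \quad\text{(then through } f1\text{)}

so the composed VJP maps :math:`c + dj` back to itself, as we expect for the identity function.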