No thread safety on C++ hooks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Autograd relies on the user to write thread-safe C++ hooks. If you want a hook
to be applied correctly in a multithreaded environment, you will need to write
proper thread-locking code to ensure that the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd with complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
that compute the real and imaginary parts of the function:

.. code:: python

    def F(z):
        x, y = torch.real(z), torch.imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is the unit imaginary number.
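
As a concrete running example (ours, used again below), take :math:`F(z) = z^2`,
for which :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`:

.. code:: python

    def u(x, y):
        # real part of z**2, where z = x + y * 1j
        return x * x - y * y

    def v(x, y):
        # imaginary part of z**2
        return 2 * x * y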

We define the :math:`JVP` of the function :math:`F` at :math:`(x, y)`, applied to a tangent
vector :math:`c + dj \in ℂ`, as:

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

where

.. math::
    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix}

This is similar to the definition of the JVP for a function :math:`ℝ^2 → ℝ^2`; the left
multiplication by the row vector :math:`[1, 1j]` is used to identify the result with a complex number.
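
As a quick sanity check of this definition (our worked example, not part of
the original text), for :math:`F(z) = z^2` the Jacobian at :math:`z = x + yj` is

.. math::
    J = \begin{bmatrix} 2x & -2y \\ 2y & 2x \end{bmatrix}

and the JVP applied to a tangent :math:`c + dj` evaluates to
:math:`(2x + 2yj)(c + dj) = 2z \cdot (c + dj)`, which is exactly the action of
the holomorphic derivative :math:`F'(z) = 2z`.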

We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c + dj \in ℂ` as:

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed during backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above. Please refer to
the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
for an explanation of these negative signs.
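
To make the formula concrete, here is a hand evaluation of the VJP for the
:math:`F(z) = z^2` example (a sketch of ours in plain Python; no autograd is
involved):

.. code:: python

    # VJP of F(z) = z**2 at z = 3 + 2j with covector w = 1 + 0j,
    # using the real 2x2 Jacobian J = [[du/dx, du/dy], [dv/dx, dv/dy]]
    # with u(x, y) = x**2 - y**2 and v(x, y) = 2 * x * y.
    x, y = 3.0, 2.0
    c, d = 1.0, 0.0

    J = [[2 * x, -2 * y],
         [2 * y,  2 * x]]

    # row = [c, -d] * J
    row = [c * J[0][0] - d * J[1][0],
           c * J[0][1] - d * J[1][1]]

    # vjp = row * [1, -1j]^T
    vjp = row[0] + row[1] * (-1j)
    print(vjp)  # (6+4j) == 2 * (3 + 2j), i.e. F'(z) for covector 1.0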

**What happens if I call backward() on a complex scalar?**
*******************************************************************************

The gradient of a complex function is computed assuming the function is holomorphic.
This is because for a general :math:`ℂ → ℂ` function, the Jacobian has 4 real-valued degrees of freedom
(as in the :math:`2 \times 2` Jacobian matrix above), so we cannot hope to represent all of them within a
single complex number. However, for holomorphic functions, the gradient can be fully represented with
complex numbers, because the Cauchy-Riemann equations ensure that the :math:`2 \times 2` Jacobian has the
special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex
number under multiplication. We can therefore obtain that gradient using ``backward``, which is just a
call to `vjp` with the covector `1.0`.
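
For the :math:`F(z) = z^2` example, the Cauchy-Riemann equations are easy to
verify directly (again our check, consistent with the Jacobian above):

.. math::
    \frac{\partial u}{\partial x} = 2x = \frac{\partial v}{\partial y}, \qquad
    \frac{\partial u}{\partial y} = -2y = -\frac{\partial v}{\partial x}

so :math:`J` is a scale-and-rotate matrix, and its action on :math:`(c, d)` is
multiplication by the single complex number :math:`2z`.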

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded by :func:`torch.autograd.backward` on a complex scalar
(i.e., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.
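
A minimal sketch of passing an explicit covector (our example; it assumes a
PyTorch build with complex autograd support, and the exact value of ``z.grad``
follows whichever VJP convention your version implements):

.. code:: python

    import torch

    z = torch.tensor(3.0 + 2.0j, requires_grad=True)
    out = z * z  # complex scalar output

    # Instead of relying on the implicit covector 1.0, pass one explicitly.
    torch.autograd.backward(out, grad_tensors=torch.tensor(1.0 + 0.0j))
    print(z.grad)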

**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and on the behavior we expect (going from :math:`ℂ → ℝ^2 → ℂ` should be
the identity), we use the formulas given below for cross-domain functions.

The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

The :math:`JVP` and :math:`VJP` for a function :math:`f2: ℝ^2 → ℂ` are defined as:

.. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
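
As a consistency check (ours), take :math:`f1(z) = (x, y)` and
:math:`f2(x, y) = x + yj`, so that :math:`f2 \circ f1` is the identity on
:math:`ℂ` and both Jacobians are the :math:`2 \times 2` identity matrix
:math:`I`. Pushing a tangent :math:`c + dj` forward through the two JVPs gives

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * I * \left( I * \begin{bmatrix} c \\ d \end{bmatrix} \right) = c + dj

and pulling a cotangent :math:`c + dj` back through the two VJPs first gives
the real covector :math:`\begin{bmatrix} c & -d \end{bmatrix} * I` for the
intermediate :math:`ℝ^2` value, and then

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * I * \begin{bmatrix} 1 \\ -1j \end{bmatrix} = c + dj

so the round trip is indeed the identity, as required.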