Commit be0cb21

complex autograd doc
ghstack-source-id: 8fd0118
Pull Request resolved: #41012

File tree

1 file changed: +105 -0 lines changed


docs/source/notes/autograd.rst

Lines changed: 105 additions & 0 deletions
@@ -210,3 +210,108 @@ No thread safety on C++ hooks
Autograd relies on the user to write thread-safe C++ hooks. If you want the hook
to be correctly applied in a multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd for complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions u and v
which compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where *1j* is the imaginary unit.
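
For example (an illustrative sketch of our own, not part of the PyTorch API), one such
decomposition for the concrete choice :math:`F(z) = z^2` is:

.. code::

    # Hypothetical example: F(z) = z**2, with u(x, y) = x**2 - y**2 and v(x, y) = 2*x*y.
    def u(x, y):
        return x ** 2 - y ** 2

    def v(x, y):
        return 2 * x * y

    def F(z):
        x, y = z.real, z.imag
        return u(x, y) + v(x, y) * 1j

    assert abs(F(1 + 2j) - (1 + 2j) ** 2) < 1e-12  # both evaluate to -3+4j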

The JVP and VJP for the function :math:`F` at :math:`(x, y)` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

.. code::

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J * [1, -1j]

where

.. math::

    J = \begin{bmatrix}
        \partial_0 u(x, y) & \partial_1 u(x, y) \\
        \partial_0 v(x, y) & \partial_1 v(x, y)
        \end{bmatrix}

In PyTorch, the VJP is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above.
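
As an illustrative check (a NumPy sketch of our own, not something the autograd engine exposes),
we can evaluate the VJP formula above for the holomorphic function :math:`F(z) = z^2` and confirm
that the covector `1.0` recovers the complex derivative :math:`2z`:

.. code::

    import numpy as np

    # Jacobian of F(z) = z**2 written as u(x, y) = x**2 - y**2, v(x, y) = 2*x*y.
    def jacobian(x, y):
        return np.array([[2 * x, -2 * y],
                         [2 * y,  2 * x]])

    # [c, -d]^T * J * [1, -1j] from the formula above.
    def vjp(cotangent, x, y):
        c, d = cotangent.real, cotangent.imag
        return np.array([c, -d]) @ jacobian(x, y) @ np.array([1, -1j])

    z = 1.0 + 2.0j
    assert np.isclose(vjp(1.0 + 0.0j, z.real, z.imag), 2 * z)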

**Why is there a negative sign in the formula above?**
******************************************************

For a function :math:`F: V → W`, where V and W are vector spaces, the output of
the Vector-Jacobian Product :math:`VJP: V → (W^* → V^*)` is a linear map
from :math:`W^* → V^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).

The negative signs in the above `VJP` computation are due to conjugation. The first
vector in the output returned by `VJP` for a given cotangent is a covector (:math:`\in ℂ^*`),
and the last vector in the output is used to map the result back into :math:`ℂ`,
since the final result of reverse-mode differentiation of a function is a covector belonging
to :math:`ℂ^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).
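
As a concrete illustration (a worked example of our own based on the formula above), take
:math:`F(z) = \bar{z}`, so that :math:`u(x, y) = x`, :math:`v(x, y) = -y`, and
:math:`J = \mathrm{diag}(1, -1)`. Plugging a cotangent `w = c + d*1j` into the formula gives
`[c, -d]^T * J * [1, -1j] = c - d*1j`, i.e. the conjugate of `w`; for the identity function
:math:`F(z) = z`, on the other hand, the two negations cancel and the VJP returns `w` unchanged.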

**What happens if I call backward() on a complex scalar?**
**********************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them within a single complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers because the
Cauchy-Riemann equations ensure that the `2x2` Jacobian has the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using backward, which is just a call to `vjp` with the covector `1.0`.
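
For instance (continuing the illustrative :math:`F(z) = z^2` example from above), the Jacobian

.. math::

    J = \begin{bmatrix}
        2x & -2y \\
        2y & 2x \end{bmatrix}

already has this scale-and-rotate form: it acts on :math:`(x, y)` exactly as multiplication by the
single complex number :math:`2x + 2y \cdot 1j = 2z`, which is the complex derivative of :math:`F`.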

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded by :func:`torch.autograd.backward` on a complex scalar
(i.e., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.
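
As a usage sketch (illustrative only; complex autograd support was experimental at the time of this
change, and the exact behavior may differ across PyTorch versions), passing the covector explicitly
could look like:

.. code::

    import torch

    z = torch.tensor(1.0 + 2.0j, requires_grad=True)
    loss = z * z  # complex scalar output

    # Supply the covector explicitly instead of relying on the default `1.0`;
    # per the convention above, this is a call to the VJP with that covector.
    torch.autograd.backward(loss, grad_tensors=torch.tensor(1.0 + 0.0j))
    print(z.grad)

The value of `z.grad` then reflects the VJP evaluated at the supplied covector, per the convention described above.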

**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be
the identity; see the sketch at the end of this section), we use the following formulas for cross-domain functions.

The JVP and VJP for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, d]^T * J * [1, -1j]

The JVP and VJP for a function :math:`f1: ℝ^2 → ℂ` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J
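
To make the motivating identity concrete, here is an illustrative NumPy sketch (our own, not part of
the documentation or the autograd engine) that chains the two VJP rules above for the "unpack" function
:math:`ℂ → ℝ^2`, :math:`z \mapsto (x, y)`, and the "pack" function :math:`ℝ^2 → ℂ`,
:math:`(a, b) \mapsto a + b \cdot 1j`, both of which have the identity matrix as their Jacobian.
It assumes the :math:`ℝ^2` cotangent is passed around as a pair `(c, d)`:

.. code::

    import numpy as np

    J = np.eye(2)  # Jacobian of both the pack and unpack functions

    # VJP rule for ℝ^2 → ℂ: [c, -d]^T * J
    def vjp_pack(cotangent):
        c, d = cotangent.real, cotangent.imag
        return np.array([c, -d]) @ J

    # VJP rule for ℂ → ℝ^2: [c, d]^T * J * [1, -1j]
    def vjp_unpack(cotangent):
        c, d = cotangent  # the ℝ^2 covector produced by vjp_pack
        return np.array([c, d]) @ J @ np.array([1, -1j])

    w = 3.0 - 4.0j
    assert np.isclose(vjp_unpack(vjp_pack(w)), w)  # the round trip is the identity

Reverse mode applies the VJP of the last function first, which is why `vjp_pack` runs before `vjp_unpack` in the chained call.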
