
Commit 488d436

committed
complex autograd doc
ghstack-source-id: ea1417c Pull Request resolved: #41012
1 parent 0edbe6b commit 488d436

File tree

1 file changed: +113 -0 lines changed


docs/source/notes/autograd.rst

Lines changed: 113 additions & 0 deletions
@@ -210,3 +210,116 @@ No thread safety on C++ hooks
Autograd relies on the user to write thread safe C++ hooks. If you want the hook
to be correctly applied in multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd for Complex Numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions :math:`u` and :math:`v`
which compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is the imaginary unit.
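
As a concrete illustration, take :math:`F(z) = z^2`, which decomposes as
:math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`:

.. code::

    # Illustrative decomposition of F(z) = z**2 into real and imaginary parts.
    def u(x, y):
        return x**2 - y**2

    def v(x, y):
        return 2 * x * y

    def F(z):
        x, y = z.real, z.imag
        return u(x, y) + v(x, y) * 1j

    assert F(3 + 4j) == (3 + 4j) ** 2  # (3+4j)**2 == -7+24j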

We define the JVP for function :math:`F` at :math:`(x, y)` applied to a tangent
vector :math:`c+dj \in ℂ` as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

where

.. math::

    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix}

This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the multiplication
with :math:`[1, 1j]^T` is used to identify the result as a complex number.
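
As a minimal sketch of this formula (using NumPy for illustration, not PyTorch library code):
for the holomorphic function :math:`F(z) = z^2`, the JVP above agrees with multiplying the
tangent by the complex derivative :math:`F'(z) = 2z`:

.. code::

    import numpy as np

    z = 3.0 + 4.0j
    x, y = z.real, z.imag

    # Real 2x2 Jacobian of (u, v) = (x**2 - y**2, 2*x*y) at (x, y).
    J = np.array([[2 * x, -2 * y],
                  [2 * y,  2 * x]])

    def JVP(tangent):
        c, d = tangent.real, tangent.imag
        # [1, 1j]^T * J * [c, d], written with explicit matrix products.
        return np.array([1, 1j]) @ (J @ np.array([c, d]))

    tangent = 1.0 - 2.0j
    assert np.isclose(JVP(tangent), 2 * z * tangent)  # holomorphic: JVP is F'(z) * tangent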

We define the VJP of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in ℂ` as:

.. code::

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J * [1, -1j]

In PyTorch, the VJP is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above.
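
As a minimal sketch of the effect of these negations (again NumPy, for illustration only),
consider the non-holomorphic function :math:`F(z) = \bar{z}` with :math:`u(x, y) = x` and
:math:`v(x, y) = -y`; the formula above maps a cotangent to its conjugate:

.. code::

    import numpy as np

    # Real 2x2 Jacobian of F(z) = conj(z), i.e. u(x, y) = x, v(x, y) = -y.
    J = np.array([[1.0,  0.0],
                  [0.0, -1.0]])

    def VJP(cotangent):
        c, d = cotangent.real, cotangent.imag
        # [c, -d]^T * J * [1, -1j]: note the negated d and the negated 1j.
        return np.array([c, -d]) @ J @ np.array([1, -1j])

    w = 2.0 + 3.0j
    assert np.isclose(VJP(w), np.conj(w))  # conj(z) backpropagates the conjugated cotangent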

**Why is there a negative sign in the formula above?**
******************************************************

For a function :math:`F: V → W`, where :math:`V` and :math:`W` are vector spaces, the output of
the Vector-Jacobian Product :math:`VJP: V → (W^* → V^*)` is a linear map
from :math:`W^*` to :math:`V^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).

The negative signs in the above `VJP` computation are due to conjugation. :math:`c-dj`
is the covector in the dual space :math:`ℂ^*` corresponding to the cotangent vector :math:`c+dj`.
The multiplication by :math:`[1, -1j]`, whose net effect is to take the conjugate of the complex
number we would have obtained by multiplication with :math:`[1, 1j]`, is used to bring the result
back into :math:`ℂ`, since the final result of reverse-mode differentiation of a function
is a covector belonging to :math:`ℂ^*` (explained in
`Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).
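
Concretely, writing :math:`[r_1, r_2]` for the real row vector :math:`[c, -d]^T J`, multiplying
by :math:`[1, -1j]` instead of :math:`[1, 1j]` just flips the sign of the imaginary part, i.e.
it conjugates the result:

.. math::

    [r_1, r_2] \cdot [1, -1j] = r_1 - r_2 j = \overline{r_1 + r_2 j} = \overline{[r_1, r_2] \cdot [1, 1j]}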

**What happens if I call backward() on a complex scalar?**
*******************************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them in a single complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
Cauchy-Riemann equations, which ensure that the `2x2` Jacobian has the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using backward, which is just a call to `vjp` with covector `1.0`.

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.
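
As a minimal sketch of why this works (NumPy, for illustration only, not PyTorch code): for the
holomorphic function :math:`F(z) = z^2`, the Jacobian has the scale-and-rotate form
`[[a, -b], [b, a]]` with :math:`F'(z) = a + bj`, and the VJP above with covector `1.0`
recovers :math:`F'(z) = 2z`:

.. code::

    import numpy as np

    z = 3.0 + 4.0j
    x, y = z.real, z.imag

    # For holomorphic F(z) = z**2 with F'(z) = a + bj, the Jacobian is a
    # scale-and-rotate matrix [[a, -b], [b, a]] (Cauchy-Riemann equations).
    a, b = (2 * z).real, (2 * z).imag
    J = np.array([[a, -b],
                  [b,  a]])

    def VJP(cotangent):
        c, d = cotangent.real, cotangent.imag
        return np.array([c, -d]) @ J @ np.array([1, -1j])

    # Backward on a complex scalar corresponds to a VJP with covector 1.0,
    # which here recovers the complex derivative 2 * z.
    assert np.isclose(VJP(1.0 + 0.0j), 2 * z)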

**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an identity),
we use the following formulas for cross-domain functions (see the sketch after the definitions below).

The JVP and VJP for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, d]^T * J * [1, -1j]

The JVP and VJP for a function :math:`f2: ℝ^2 → ℂ` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J
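
As a sanity check of the identity mentioned above (a NumPy sketch, for illustration only),
composing :math:`f1(z) = (real(z), imag(z))` with :math:`f2(x, y) = x + yj` gives the identity
on :math:`ℂ` for both the JVP and the VJP defined here:

.. code::

    import numpy as np

    # Both f1(z) = (real(z), imag(z)) and f2(x, y) = x + y * 1j have the
    # 2x2 identity matrix as their real Jacobian.
    J1 = np.eye(2)  # f1: C -> R^2
    J2 = np.eye(2)  # f2: R^2 -> C

    def jvp_f1(tangent):              # complex tangent -> R^2 tangent
        c, d = tangent.real, tangent.imag
        return J1 @ np.array([c, d])

    def jvp_f2(tangent):              # R^2 tangent -> complex tangent
        return np.array([1, 1j]) @ (J2 @ tangent)

    def vjp_f2(cotangent):            # complex cotangent -> R^2 cotangent
        c, d = cotangent.real, cotangent.imag
        return np.array([c, -d]) @ J2

    def vjp_f1(cotangent):            # R^2 cotangent -> complex cotangent
        c, d = cotangent[0], cotangent[1]
        return np.array([c, d]) @ J1 @ np.array([1, -1j])

    w = 5.0 - 2.0j
    assert np.isclose(jvp_f2(jvp_f1(w)), w)   # forward mode: identity
    assert np.isclose(vjp_f1(vjp_f2(w)), w)   # reverse mode: identity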
