@@ -210,3 +210,116 @@ No thread safety on C++ hooks
210210Autograd relies on the user to write thread safe C++ hooks. If you want the hook
211211to be correctly applied in multithreading environment, you will need to write
212212proper thread locking code to ensure the hooks are thread safe.
213+
214+ Autograd for Complex Numbers
215+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
216+
217+ **What notion of complex derivative does PyTorch use? **
218+ *******************************************************
219+
220+ PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation >`_
221+ convention for autograd for Complex Numbers.
222+
223+ Suppose we have a function :math: `F: ℂ → ℂ` which we can decompose into functions u and v
224+ which compute the real and imaginary parts of the function:
225+
226+ .. code ::
227+
228+ def F(z):
229+ x, y = real(z), imag(z)
230+ return u(x, y) + v(x, y) * 1j
231+
232+ where :math: `1 j` is a unit imaginary number.
233+
234+ We define the JVP for function :math: `F` at :math: `(x, y)` applied to a tangent
235+ vector :math: `c+dj \in C` as:
236+
237+ .. code ::
238+
239+ def JVP(tangent):
240+ c, d = real(tangent), imag(tangent)
241+ return [1, 1j]^T * J * [c, d]
242+
243+ where
244+
245+ .. math ::
246+
247+ J = \begin {bmatrix}
248+ \frac {\partial u(x, y)}{\partial x} & \frac {\partial u(x, y)}{\partial y}\\
249+ \frac {\partial v(x, y)}{\partial x} & \frac {\partial v(x, y)}{\partial y} \end {bmatrix} \\
250+
251+ This is similar to the definition of the JVP for a function defined from :math: `R^2 → R^2 `, and the multiplication
252+ with :math: `[1 , 1 j]^T` is used to identify the result as a complex number.
253+
254+ We define the VJP of :math: `F` at :math: `(x, y)` for a cotangent vector :math: `c+dj \in C` as:
255+
256+ .. code ::
257+
258+ def VJP(cotangent):
259+ c, d = real(cotangent), imag(cotangent)
260+ return [c, -d]^T * J * [1, -1j]
261+
262+ In PyTorch, the VJP is mostly what we care about, as it is the computation performed when we do backward
263+ mode automatic differentiation. Notice that d and :math: `1 j` are negated in the formula above.
264+
265+ **Why is there a negative sign in the formula above? **
266+ ******************************************************
267+
268+ For a function F: V → W, where are V and W are vector spaces. The output of
269+ the Vector-Jacobian Product :math: `VJP : V → (W^* → V^*)` is a linear map
270+ from :math: `W^* → V^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf >`_).
271+
272+ The negative signs in the above `VJP ` computation are due to conjugation. :math: `c-dj`
273+ is the covector in dual space of :math: `C^` (:math: `\in ℂ^*`) corresponding to the
274+ cotangent vector :math: `c+dj`, and the multiplication by :math: `[1 , -1 j]`, whose net effect is
275+ to get a conjugate of the complex number we would have obtained by multiplcation with :math: `[1 , 1 j]` instead,
276+ is used to get the result in :math: `ℂ` since the final result of reverse-mode differentiation of a function
277+ is a covector belonging to :math: `ℂ^*` (explained in
278+ `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf >`_).
279+
280+ **What happens if I call backward() on a complex scalar? **
281+ *******************************************************************************
282+
283+ The gradient for a complex function is computed assuming the input function is a holomorphic function.
284+ This is because for general :math: `ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
285+ (as in the `2x2 ` Jacobian matrix above), so we can’t hope to represent all of them with in a complex number.
286+ However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
287+ Cauchy-Riemann equations that ensure that `2x2 ` Jacobians have the special form of a scale-and-rotate
288+ matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
289+ obtain that gradient using backward which is just a call to `vjp ` with covector `1.0 `.
290+
291+ The net effect of this assumption is that the partial derivatives of the imaginary part of the function
292+ (:math: `v(x, y)` above) are discarded for :func: `torch.autograd.backward ` on a complex scalar
293+ (e.g., this is equivalent to dropping the imaginary part of the loss before performing a backwards).
294+
295+ For any other desired behavior, you can specify the covector `grad_output ` in :func: `torch.autograd.backward ` call accordingly.
296+
297+ **How are the JVP and VJP defined for cross-domain functions? **
298+ ***************************************************************
299+
300+ Based on formulas above and the behavior we expect to see (going from :math: `ℂ → ℝ^2 → ℂ` should be an identity),
301+ we use the following formula for cross-domain functions.
302+
303+ The JVP and VJP for a :math: `f1 : ℂ → ℝ^2 ` are defined as:
304+
305+ .. code ::
306+
307+ def JVP(tangent):
308+ c, d = real(tangent), imag(tangent)
309+ return J * [c, d]
310+
311+ def VJP(cotangent):
312+ c, d = real(cotangent), imag(cotangent)
313+ return [c, d]^T * J * [1, -1j]
314+
315+ The JVP and VJP for a :math: `f1 : ℝ^2 → ℂ` are defined as:
316+
317+ .. code ::
318+
319+ def JVP(tangent):
320+ c, d = real(tangent), imag(tangent)
321+ return [1, 1j]^T * J * [c, d]
322+
323+ def VJP(cotangent):
324+ c, d = real(cotangent), imag(cotangent)
325+ return [c, -d]^T * J
0 commit comments