No thread safety on C++ hooks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Autograd relies on the user to write thread safe C++ hooks. If you want the hook
to be correctly applied in a multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
**********************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd for Complex Numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions
:math:`u` and :math:`v` that compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where *1j* is the imaginary unit.
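
For example (a hypothetical function chosen purely for illustration, not part of the convention itself),
taking :math:`F(z) = z^2` gives the decomposition

.. math::

    u(x, y) = x^2 - y^2, \qquad v(x, y) = 2xy

since :math:`(x + y \cdot 1j)^2 = x^2 - y^2 + 2xy \cdot 1j`.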

The JVP and VJP for the function :math:`F` at :math:`(x, y)` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

.. code::

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J * [1, -1j]

where

.. math::

    J = \begin{bmatrix}
        \partial_0 u(x, y) & \partial_1 u(x, y)\\
        \partial_0 v(x, y) & \partial_1 v(x, y)
        \end{bmatrix}

In PyTorch, the VJP is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above.
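
To make the two definitions concrete, here is a minimal, self-contained sketch (an illustration only, not
PyTorch's implementation) that hard-codes the hypothetical :math:`F(z) = z^2` from the example above, for which
:math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`, and evaluates the JVP and VJP formulas with NumPy:

.. code::

    import numpy as np

    # Hand-written 2x2 real Jacobian of the example F(z) = z**2:
    # u(x, y) = x**2 - y**2, v(x, y) = 2*x*y.
    def jacobian(x, y):
        return np.array([[2 * x, -2 * y],
                         [2 * y,  2 * x]])

    def jvp(z, tangent):
        c, d = tangent.real, tangent.imag
        # [1, 1j]^T * J * [c, d]
        return np.array([1, 1j]) @ jacobian(z.real, z.imag) @ np.array([c, d])

    def vjp(z, cotangent):
        c, d = cotangent.real, cotangent.imag
        # [c, -d]^T * J * [1, -1j]
        return np.array([c, -d]) @ jacobian(z.real, z.imag) @ np.array([1, -1j])

    z = 1.0 + 2.0j
    print(jvp(z, 1.0))  # (2+4j), i.e. F'(z) = 2z for this holomorphic example
    print(vjp(z, 1.0))  # (2+4j) as well, for the covector 1.0

Because this particular :math:`F` happens to be holomorphic, both calls reduce to multiplication by
:math:`F'(z) = 2z`, which previews the discussion of holomorphic functions below.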

**Why is there a negative sign in the formula above?**
*********************************************************

For a function :math:`F: V → W`, where :math:`V` and :math:`W` are vector spaces, the output of
the Vector-Jacobian Product :math:`VJP: V → (W^* → V^*)` is a linear map
from :math:`W^*` to :math:`V^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).

The negative signs in the above `VJP` computation are due to conjugation. The first
vector in the `VJP` formula, built from the given cotangent, is a covector (:math:`\in ℂ^*`),
and the last vector is used to get the result back in :math:`ℂ`,
since the final result of reverse-mode differentiation of a function is a covector belonging
to :math:`ℂ^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).
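
Concretely (this is just the VJP formula above rewritten, not an additional rule), if we write the cotangent
as :math:`w = c + d \cdot 1j`, then the row vector :math:`[c, -d]` contains exactly the real and imaginary
parts of the conjugate :math:`\bar{w} = c - d \cdot 1j`:

.. math::

    VJP(w) = \begin{bmatrix} c & -d \end{bmatrix} J \begin{bmatrix} 1 \\ -1j \end{bmatrix},
    \qquad c + (-d) \cdot 1j = \bar{w}

so the negative signs are where the conjugation of the cotangent (and of the output basis :math:`[1, -1j]`)
shows up.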

**What happens if I call backward() on a complex scalar?**
*************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them within a complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
Cauchy-Riemann equations, which ensure that `2x2` Jacobians have the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using backward, which is just a call to `vjp` with covector `1.0`.
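
Spelled out (a restatement of the Cauchy-Riemann condition, not a new assumption), holomorphy forces
:math:`\partial_0 u = \partial_1 v` and :math:`\partial_1 u = -\partial_0 v`, so the Jacobian takes the form

.. math::

    J = \begin{bmatrix}
        a & -b\\
        b & a \end{bmatrix}

which acts on :math:`(x, y)` exactly like multiplication by the single complex number :math:`a + b \cdot 1j = F'(z)`.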

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.
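
For instance, a minimal sketch of passing an explicit covector (assuming a PyTorch build with complex
autograd support; the function and values are purely illustrative):

.. code::

    import torch

    z = torch.tensor(1.0 + 2.0j, requires_grad=True)
    out = z * z  # complex scalar output

    # Pass the covector explicitly via grad_tensors (the grad_output referred to
    # above) instead of relying on an implicit one.
    torch.autograd.backward(out, grad_tensors=torch.ones_like(out))
    print(z.grad)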

**How are the JVP and VJP defined for cross-domain functions?**
******************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be the identity),
we use the following formulas for cross-domain functions.

The JVP and VJP for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, d]^T * J * [1, -1j]

The JVP and VJP for a function :math:`f1: ℝ^2 → ℂ` are defined as:

.. code::

    def JVP(tangent):
        c, d = real(tangent), imag(tangent)
        return [1, 1j]^T * J * [c, d]

    def VJP(cotangent):
        c, d = real(cotangent), imag(cotangent)
        return [c, -d]^T * J
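
As a sanity check of the identity mentioned above (an illustrative sketch only; the helper maps `g` and `h`
and the use of NumPy arrays for the :math:`ℝ^2` values are assumptions made for this example), take `g` to map
:math:`z` to its real and imaginary parts and `h` to map them back. Both have the identity as their real
:math:`2 \times 2` Jacobian, and composing the two directions reproduces the original tangent and cotangent:

.. code::

    import numpy as np

    # Both g: ℂ → ℝ^2 and h: ℝ^2 → ℂ have J = identity; ℝ^2 tangents/cotangents
    # are represented as plain length-2 arrays here, so the real/imag unpacking
    # in the formulas above becomes direct use of the array.
    J = np.eye(2)

    def jvp_g(tangent):        # ℂ → ℝ^2: J * [c, d]
        c, d = tangent.real, tangent.imag
        return J @ np.array([c, d])

    def jvp_h(tangent):        # ℝ^2 → ℂ: [1, 1j]^T * J * [c, d]
        return np.array([1, 1j]) @ (J @ tangent)

    def vjp_h(cotangent):      # ℝ^2 → ℂ: [c, -d]^T * J
        c, d = cotangent.real, cotangent.imag
        return np.array([c, -d]) @ J

    def vjp_g(cotangent):      # ℂ → ℝ^2: [c, d]^T * J * [1, -1j]
        return cotangent @ J @ np.array([1, -1j])

    t = 3.0 - 4.0j
    print(jvp_h(jvp_g(t)))     # (3-4j): the composed JVP is the identity
    print(vjp_g(vjp_h(t)))     # (3-4j): the composed VJP is the identity as well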