No thread safety on C++ hooks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Autograd relies on the user to write thread-safe C++ hooks. If you want the hook
to be correctly applied in a multithreaded environment, you will need to write
proper thread-locking code to ensure that the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*********************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd for complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions u and v
which compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is the unit imaginary number.
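
As a concrete instance of this decomposition (illustrative, not taken from the docs), take :math:`F(z) = z^2`, for which :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`:

```python
# Illustrative decomposition of F(z) = z**2 into real-valued u and v:
# (x + yj)**2 = (x**2 - y**2) + (2*x*y)*1j
def u(x, y):
    return x * x - y * y

def v(x, y):
    return 2 * x * y

def F(z):
    x, y = z.real, z.imag
    return u(x, y) + v(x, y) * 1j

print(F(3 + 4j))  # matches (3 + 4j) * (3 + 4j) = (-7+24j)
```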

We define the :math:`JVP` for function :math:`F` at :math:`(x, y)` applied to a tangent
vector :math:`c+dj \in ℂ` as:

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

where

.. math::
    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix}

This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the multiplication
with :math:`[1, 1j]^T` is used to identify the result as a complex number.
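
The formula can be checked numerically. The sketch below (function names and the choice :math:`F(z) = z^2` are illustrative assumptions, not PyTorch API) estimates the real `2x2` Jacobian by central differences and applies :math:`[1, 1j] * J * [c, d]^T`; for this holomorphic :math:`F` the JVP reduces to multiplying the tangent by :math:`f'(z) = 2z`:

```python
def F(z):
    return z * z  # an illustrative holomorphic function with f'(z) = 2z

def jacobian(F, z, eps=1e-6):
    # Central-difference estimate of the real 2x2 Jacobian; rows are
    # (u, v) = (Re F, Im F), columns are partials w.r.t. x and y.
    x, y = z.real, z.imag
    dFdx = (F(complex(x + eps, y)) - F(complex(x - eps, y))) / (2 * eps)
    dFdy = (F(complex(x, y + eps)) - F(complex(x, y - eps))) / (2 * eps)
    return [[dFdx.real, dFdy.real],
            [dFdx.imag, dFdy.imag]]

def jvp(F, z, tangent):
    c, d = tangent.real, tangent.imag
    J = jacobian(F, z)
    # [1, 1j] @ J @ [c, d]^T
    return (J[0][0] * c + J[0][1] * d) + (J[1][0] * c + J[1][1] * d) * 1j

z, t = 1 + 2j, 0.5 - 1j
assert abs(jvp(F, z, t) - 2 * z * t) < 1e-4  # JVP = f'(z) * tangent here
```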

We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in ℂ` as:

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above. Please refer to
the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
for an explanation of the negative signs in the formula.
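
The same kind of numeric sketch works for the VJP (again, the helper names and :math:`F(z) = z^2` are illustrative assumptions): apply :math:`[c, -d] * J * [1, -1j]^T` to a finite-difference Jacobian. For this holomorphic :math:`F`, the VJP under this convention works out to :math:`f'(z)` times the cotangent:

```python
def F(z):
    return z * z  # illustrative holomorphic function, f'(z) = 2z

def jacobian(F, z, eps=1e-6):
    # Central-difference 2x2 Jacobian; rows (Re F, Im F), columns (x, y).
    x, y = z.real, z.imag
    dFdx = (F(complex(x + eps, y)) - F(complex(x - eps, y))) / (2 * eps)
    dFdy = (F(complex(x, y + eps)) - F(complex(x, y - eps))) / (2 * eps)
    return [[dFdx.real, dFdy.real],
            [dFdx.imag, dFdy.imag]]

def vjp(F, z, cotangent):
    c, d = cotangent.real, cotangent.imag
    J = jacobian(F, z)
    # [c, -d] @ J gives a real row vector [r0, r1] ...
    r0 = c * J[0][0] - d * J[1][0]
    r1 = c * J[0][1] - d * J[1][1]
    # ... which is then multiplied with [1, -1j]^T
    return r0 - r1 * 1j

z, w = 1 + 2j, 2 + 3j
assert abs(vjp(F, z, w) - 2 * z * w) < 1e-4  # here VJP = f'(z) * cotangent
```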

**What happens if I call backward() on a complex scalar?**
************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them within a complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the
Cauchy-Riemann equations, which ensure that `2x2` Jacobians have the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using backward, which is just a call to `vjp` with covector `1.0`.

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).
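
A small sketch (with a hypothetical non-holomorphic :math:`F`, not taken from the docs) makes the "imaginary part is discarded" point explicit: with covector `1.0` (:math:`c=1, d=0`), the VJP formula only ever reads the partials of :math:`u = \operatorname{Re} F`:

```python
def F(z):
    x, y = z.real, z.imag
    return x * x + (x * y) * 1j  # u = x^2, v = x*y (hypothetical example)

def vjp_with_covector_one(F, z, eps=1e-6):
    # With cotangent 1.0 (c=1, d=0), the VJP formula reduces to
    # du/dx - (du/dy) * 1j; the partials of v never enter.
    x, y = z.real, z.imag
    du_dx = (F(complex(x + eps, y)).real - F(complex(x - eps, y)).real) / (2 * eps)
    du_dy = (F(complex(x, y + eps)).real - F(complex(x, y - eps)).real) / (2 * eps)
    return du_dx - du_dy * 1j

# Same answer as differentiating Re F alone: d(x^2)/dx = 2x, d(x^2)/dy = 0
assert abs(vjp_with_covector_one(F, 3 + 4j) - 6) < 1e-4
```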

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.

**How are the JVP and VJP defined for cross-domain functions?**
****************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an identity),
we use the formula given below for cross-domain functions.

The :math:`JVP` and :math:`VJP` for a function :math:`f1 : ℂ → ℝ^2` are defined as:

.. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

The :math:`JVP` and :math:`VJP` for a function :math:`f1 : ℝ^2 → ℂ` are defined as:

.. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
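
The identity round trip mentioned above can be checked numerically. This sketch (the names `f1`/`f2` and both helpers are illustrative assumptions) pulls a cotangent back through :math:`ℝ^2 → ℂ` and then :math:`ℂ → ℝ^2`, using identity Jacobians for both maps, and recovers the cotangent unchanged:

```python
def vjp_c_to_r2(J, c, d):
    # VJP = [c, d] @ J @ [1, -1j]^T for f1 : C -> R^2
    r0 = c * J[0][0] + d * J[1][0]
    r1 = c * J[0][1] + d * J[1][1]
    return r0 - r1 * 1j

def vjp_r2_to_c(J, w):
    # VJP = [Re w, -Im w] @ J for f2 : R^2 -> C; the result is a real pair
    c, d = w.real, -w.imag
    return (c * J[0][0] + d * J[1][0], c * J[0][1] + d * J[1][1])

I2 = [[1.0, 0.0], [0.0, 1.0]]  # Jacobians of the two identity maps
w = 3 - 2j
c, d = vjp_r2_to_c(I2, w)          # pull back through f2 : R^2 -> C
assert vjp_c_to_r2(I2, c, d) == w  # then through f1 : C -> R^2: unchanged
```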