
Conversation

@anjali411
Contributor

@anjali411 anjali411 commented Jul 6, 2020

Stack from ghstack:

Differential Revision: D22476911

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Jul 6, 2020
ghstack-source-id: 850637d
Pull Request resolved: #41012
@anjali411 anjali411 requested a review from albanD July 6, 2020 16:19
@anjali411 anjali411 changed the title from "complex autograd doc" to "Autograd Doc for Complex Numbers" Jul 6, 2020
@dr-ci

dr-ci bot commented Jul 6, 2020

💊 CI failures summary and remediations

As of commit 2db4435 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

This comment has been revised 36 times.

anjali411 added a commit that referenced this pull request Jul 6, 2020
ghstack-source-id: 216fb84
Pull Request resolved: #41012
@anjali411 anjali411 requested a review from ezyang July 6, 2020 19:26
anjali411 added a commit that referenced this pull request Jul 7, 2020
ghstack-source-id: 23fe293
Pull Request resolved: #41012
anjali411 added a commit that referenced this pull request Jul 7, 2020
ghstack-source-id: 8d199c0
Pull Request resolved: #41012
What happens if I call backward() on a complex scalar?
******************************************************

1. For holomorphic functions, you get the same result as expected from using the Cauchy-Riemann equations.
Contributor

It's not clear what "same result" means here

Contributor

For holomorphic functions, the gradient can be fully represented with complex numbers due to the Cauchy-Riemann equations

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd for complex numbers.

For a function :math:`F: C → C`
Contributor

Suppose we have a function F: C -> C which we can decompose into functions u and v which compute the real and imaginary parts of the function:

x, y = real(z), imag(z)
return u(x, y) + v(x, y) * 1j
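
As a minimal, hypothetical illustration of this decomposition (the function f below is just an example, not part of the PR):

import torch

def f(z):
    # hypothetical C -> C function, f(z) = z**2, written via its real/imaginary parts
    x, y = z.real, z.imag
    u = x * x - y * y        # real part of z**2
    v = 2 * x * y            # imaginary part of z**2
    return u + v * 1j

z = torch.tensor(1.0 + 2.0j)
print(f(z), z * z)           # both evaluate to -3+4j
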
The JVP and VJP for function :math:`F` are defined as:
Contributor

for function F at (x, y)

def VJP(cotangent):
    c, d = real(cotangent), imag(cotangent)
    return [c, -d] * J * [1, -i]^T
Contributor

In PyTorch, the VJP is mostly what we care about, as it is the computation performed when we do backwards mode automatic differentiation. Notice that d and i are negated in the formula above.
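
A hedged sketch of where this shows up in practice (the function and values below are illustrative only, and assume complex autograd support for the ops used):

import torch

z = torch.tensor(2.0 + 1.0j, requires_grad=True)
out = z * z                                   # a C -> C function

# torch.autograd.grad evaluates the VJP at the supplied cotangent (grad_outputs).
cotangent = torch.tensor(1.0 - 0.5j)
(vjp_result,) = torch.autograd.grad(out, z, grad_outputs=cotangent)
print(vjp_result)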

Contributor

I might suggest absorbing "How are the JVP and VJP defined for :math:R^2 -> C and :math:C -> R^2 functions?" section into this one. The structure then is "Here is the general definition", and then "Here is a particular example"

Contributor Author

I just felt it would be cleaner to have them in separate sections, primarily because not everyone looking for information about autograd for complex functions would be interested in cross-domain functions.

******************************************************

1. For holomorphic functions, you get the same result as expected from using the Cauchy-Riemann equations.
2. For non-holomorphic functions, the partial derivatives of :math:`v(x, y)` are discarded.
Contributor

of the imaginary part of the function (v(x, y) above) are discarded (e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward). To get the gradient with respect to the imaginary components of the function, you must explicitly specify gradient=torch.tensor(1j).

Contributor Author

yeah that's easier to read. updated!
not sure what you mean by:

To get the gradient with respect to the imaginary components of the function, you must explicitly specify gradient=torch.tensor(1j).

do you mean to say for any other desired behavior, specify the grad_out accordingly?
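
For reference, a minimal sketch of the grad_out discussion above (illustrative only; assumes complex autograd support for the op used):

import torch

z = torch.tensor(1.0 + 1.0j, requires_grad=True)
out = z * z

# Calling backward() with a cotangent of 1 keeps only the real part of the output;
# passing 1j instead propagates through the imaginary part.
out.backward(gradient=torch.tensor(1j))
print(z.grad)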

How are the JVP and VJP defined for :math:`R^2 -> C` and :math:`C -> R^2` functions?
************************************************************************************

The JVP and VJP for a function :math:`f1: C → R^2` are defined as:
Contributor

I'm still a little confused about the function of this section. Are we trying to explain why a conjugation occurs when we define the vjp for view_as_complex and view_as_real? If so, it feels like it would be more direct if we directly talked about that particular case.

Contributor Author

Yeah, the goal here was to specify the formulas used in view_as_real and view_as_complex, and also to give people a general idea of how they can define their own backward for similar functions.
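
For readers unfamiliar with the two ops mentioned here, a quick illustration (values are arbitrary):

import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j])

r = torch.view_as_real(z)          # shape (2, 2); r[..., 0] is the real part, r[..., 1] the imaginary part
z_back = torch.view_as_complex(r)  # views the same storage as a complex tensor again
print(r)
print(z_back)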

def JVP(tangent):
    c, d = real(tangent), imag(tangent)
    return [1, i]^T * J * [c, d]
Collaborator

I think you have mixed use of i and 1j throughout, you should update to use a single one all over. And maybe at the beginning add a quick "(in this document, we use blah to define the imaginary number)".

Collaborator

You're still mixing j and 1j, also I can't see where you defined it?

**************************************************

For a function F: V → W, where V and W are vector spaces, the output of
the Vector-Jacobian Product :math:`VJP : V → (W^* → V^*)` is a linear map
Collaborator

nit the VJP you define above is not this function here.
The one you define above is directly the linear mapping (because you hard coded the input in space V directly into J).

Contributor Author

oh yeah good catch updated!

Collaborator

How was this updated?

Contributor

the "output of", I imagine

the Vector-Jacobian Product :math:`VJP : V → (W^* → V^*)` is a linear map
from :math:`W^* → V^*` (explained in `Chapter 4 of Dougal’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).

The negative signs in the above VJP computation are due to conjugation. The first
Collaborator

Not sure what this paragraph brings?
Why do you need the conjugation?
If the thesis contains all the details, maybe just replace this paragraph by a sentence pointing to the thesis for more details.

Contributor Author

you need conjugation because the vectors from the dual space as explained below

What happens if I call backward() on a complex scalar?
******************************************************

For general ℂ → ℂ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above),
Collaborator

I think this is missing one extra step:
The backward in pytorch does not compute a jacobian.

I think you want something that mentions that for R->R and C->R, backward() (with no argument) computes the full gradient of the function.
But for C->C functions, this is not the case...

Contributor Author

Yeah, agreed, but we are not saying the backward is computing the grad. We are just commenting on the Jacobian to explain why the gradient for holomorphic functions can still be represented as a complex number.

Collaborator

Ok, then I guess the title of the section is what confused me here.
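
As a reference for the point being made here (an illustrative summary, not text from the PR): a general ℂ → ℂ function has a 2x2 real Jacobian with 4 degrees of freedom, and the Cauchy-Riemann equations cut these down to 2 for holomorphic functions, which is why the gradient still fits in a single complex number.

.. math::
    J = \begin{bmatrix} \partial_x u & \partial_y u \\ \partial_x v & \partial_y v \end{bmatrix},
    \qquad
    \text{holomorphic: } \partial_x u = \partial_y v,\; \partial_y u = -\partial_x v
    \;\Rightarrow\; f'(z) = \partial_x u + i\,\partial_x v.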

anjali411 added a commit that referenced this pull request Jul 7, 2020
ghstack-source-id: 4386973
Pull Request resolved: #41012
@anjali411 anjali411 requested a review from albanD July 9, 2020 16:23
anjali411 added a commit that referenced this pull request Jul 9, 2020
ghstack-source-id: 8fd0118
Pull Request resolved: #41012
x, y = real(z), imag(z)
return u(x, y) + v(x, y) * 1j
where *1j* is a unit imaginary number.
Collaborator

Why bold and not in math here?

.. math::
    J = \begin{bmatrix}
        \partial_0 u(x, y) & \partial_1 u(x, y)\\
        \partial_0 v(x, y) & \partial_1 v(x, y)
        \end{bmatrix}
Collaborator

Is \partial_0 defined?
Maybe \frac{\partial u(x, y)}{\partial x} ?
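
Written out with the suggested \frac notation, the full Jacobian would read:

.. math::
    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y}
        \end{bmatrix}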

**************************************************

For a function F: V → W, where V and W are vector spaces, the output of
the Vector-Jacobian Product :math:`VJP : V → (W^* → V^*)` is a linear map
Collaborator

How was this updated?

the Vector-Jacobian Product :math:`VJP : V → (W^* → V^*)` is a linear map
from :math:`W^* → V^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).

The negative signs in the above `VJP` computation are due to conjugation. The first
Collaborator

I don't think it is clear what you mean by "the first vector in the output" here ?

Contributor Author

hmm okay rewrote it. let me know if that looks better

vector in the output returned by `VJP` for a given cotangent is a covector (:math:`\in ℂ^*`),
and the last vector in the output is used to get the result in :math:`ℂ`
since the final result of reverse-mode differentiation of a function is a covector belonging
to :math:`ℂ^*` (explained in `Chapter 4 of Dougal Maclaurin’s thesis <https://dougalmaclaurin.com/phd-thesis.pdf>`_).
Collaborator

I am still unsure of the value of this paragraph. It seems to justify the conjugation by saying that we need the output to be in C. But if we just want the output to be in C, we could do without the conjugation, no?
Is the justification here that the mapping we use from C* to C is defined based on the standard dot product on C, and so it maps vectors by doing a Hermitian transpose?

*******************************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
Collaborator

Reading more about this, it feels like the "pure" C function definition does not hold.
And to be able to get quantities that match the gradients for holomorphic functions and provide a sensible direction for use in gradient descent algorithms, modified definitions are needed.
Unfortunately, these definitions represent the full information about the derivatives of a general complex function using twice as many elements as a regular R -> R gradient we are used to.
And so we cannot expect to get these "extended gradients" in our framework, where the .grad field is hard-coded to be the same size as the Tensor it represents the gradients of.

Contributor

What do you mean by "pure" C function definition does not hold?

Collaborator

What I call "pure C function definition" here is what happens if you apply the definition of the derivative (as a limit) from real functions to complex functions: you get derivatives only for holomorphic functions.

anjali411 added a commit that referenced this pull request Jul 9, 2020
ghstack-source-id: ea1417c
Pull Request resolved: #41012
@anjali411 anjali411 requested review from albanD and ezyang July 9, 2020 19:53
**Why is there a negative sign in the formula above?**
******************************************************

For a function F: V → W, where V and W are vector spaces, the output of
Contributor

:math:`F: V -> W`

?

Contributor Author

Based on offline sync with @albanD, I removed this section to avoid providing a fuzzy or possibly incorrect explanation.

where :math:`1j` is a unit imaginary number.

We define the JVP for function :math:`F` at :math:`(x, y)` applied to a tangent
vector :math:`c+dj \in C` as:
Contributor

it's probably better to explicitly say that this is Python pseudocode. For example, you couldn't actually use * in a real PyTorch program as that would give you pointwise multiplication, not matrix product. And the bracket syntax means something else in Python, that is also not intended here either.

Or even better, make this real code using the real PyTorch operations. Then you don't have to explain the pseudocode syntax.
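
A minimal illustration of the * point (names and values here are made up for illustration):

import torch

J = torch.tensor([[1., 2.], [3., 4.]])   # stand-in for a 2x2 Jacobian
v = torch.tensor([[1.], [0.]])           # stand-in for a column vector

pointwise = J * v   # broadcasted elementwise product, shape (2, 2)
matmul = J @ v      # matrix product, shape (2, 1) -- what the pseudocode intends
print(pointwise.shape, matmul.shape)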

Contributor Author

I updated this section to just use math notation instead. It's simpler to read and avoids the confusion

Contributor

@ezyang ezyang left a comment

I think the biggest issue in my mind is explanation of the pseudocode syntax, but I'm going to approve right now to move things along.

anjali411 added a commit that referenced this pull request Jul 10, 2020
ghstack-source-id: bfc4145
Pull Request resolved: #41012
Collaborator

@albanD albanD left a comment

I think this has the main information we want here: the formula we should use for implementing other complex ops.
We can detail it and add more justification as we go.

@facebook-github-bot
Contributor

@anjali411 merged this pull request in db38487.

@facebook-github-bot facebook-github-bot deleted the gh/anjali411/39/head branch July 13, 2020 17:56
malfet pushed a commit that referenced this pull request Jul 22, 2020
Summary: Pull Request resolved: #41012

Test Plan: Imported from OSS

Differential Revision: D22476911

Pulled By: anjali411

fbshipit-source-id: 7da20cb4312a0465272bebe053520d9911475828