Labels: high priority, module: autograd, module: complex, triaged
Description
cc #755
JAX and TensorFlow disagree about whether the gradient of a complex tensor should be conjugated. Here is an easy way to see the difference:
from jax import grad

def f(z):
    return z * z

z = 0.1j
print(grad(f, holomorphic=True)(z))
gives 0.2j, i.e. the holomorphic derivative df/dz = 2z evaluated at z = 0.1j.
However, TensorFlow (1.x graph mode) returns the conjugate:
>>> x = tf.Variable(0. + 0.j)
>>> sess.run(tf.gradients(x*x, x), feed_dict={x:0.1j})
[-0.20000000000000001j]
source: tensorflow/tensorflow#3348
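For comparison, here is a minimal sketch of the same experiment in PyTorch; which of the two values it prints depends on the convention in use, as discussed below.

import torch

# Same f(z) = z * z experiment; the output is complex, so backward needs
# an explicit grad_outputs seed of 1.
z = torch.tensor(0.1j, requires_grad=True)
out = z * z
(g,) = torch.autograd.grad(out, z, grad_outputs=torch.ones_like(out))
# JAX-style convention: 0.2j (the holomorphic derivative 2z).
# TF-style convention: -0.2j (its conjugate).
print(g)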
PyTorch also has to decide which side of the fence it will come down on. Right now, master implements the JAX convention. From reading the issue, here is my understanding of the pros and cons:
In favor of TF:
- The gradient is the correct direction for doing gradient descent. This means you can use a "stock" optimizer (one that was written with only real parameters in mind) without any changes. In contrast, to do gradient descent with the JAX definition, you have to remember to conjugate first (see the sketch right after this bullet).
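To make this concrete, here is a small sketch of a plain SGD step under each convention, with the real-valued loss L(z) = |z|^2 and the derivatives written out by hand rather than produced by autograd; the helper name is purely illustrative.

import torch

def sgd_step(param, grad, lr=0.1, convention="jax"):
    # With the TF convention the gradient already points along steepest
    # ascent of a real-valued loss, so the stock update applies as-is.
    # With the JAX convention the update has to conjugate first.
    with torch.no_grad():
        if convention == "tf":
            param -= lr * grad
        else:
            param -= lr * grad.conj()
    return param

# L(z) = |z|^2; a descent step from z = 1 + 1j should move toward 0.
z = torch.tensor(1 + 1j)
jax_grad = z.conj()        # unconjugated Wirtinger derivative dL/dz = conj(z)
tf_grad = jax_grad.conj()  # TF-style gradient = conj(dL/dz) = z
print(sgd_step(z.clone(), jax_grad, convention="jax"))  # 0.9+0.9j
print(sgd_step(z.clone(), tf_grad, convention="tf"))    # 0.9+0.9j, same step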
In favor of JAX:
- When the TF definition is implemented directly in the gradient formulas (as opposed to doing a single post facto conjugation), you end up with ugly gradient formulas. For example, take a look at https://github.com/tensorflow/tensorflow/blob/70fd0a4436e3b49139653dc5b85d1c7df23f403d/tensorflow/python/ops/math_grad.py#L453 where TF has to explicitly conjugate the input. With the JAX-style definition, you can mostly reuse your real gradient formulas (see the sketch after this list).
- The TF definition is less efficient, since you're doing extra conjugations in the gradient formulas. Even if you write the derivative formulas in the JAX way and then conjugate before plopping the gradient in x.grad, you lose the opportunity to do a fused conjugate-add for the optimizer update.
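To make the formula point concrete, here is a sketch (illustrative names, not PyTorch's or TF's actual derivative code) of the backward of out = x * y written under each convention:

import torch

def mul_backward_jax_style(g, x, y):
    # Unconjugated convention: the real-valued formula carries over unchanged.
    return g * y, g * x

def mul_backward_tf_style(g, x, y):
    # Conjugated convention: every formula conjugates the saved inputs
    # (compare the explicit conj calls in the linked math_grad.py).
    return g * y.conj(), g * x.conj()

x, y, g = torch.tensor(1 + 2j), torch.tensor(3 - 1j), torch.tensor(1 + 0j)
print(mul_backward_jax_style(g, x, y))  # (3-1j, 1+2j)
print(mul_backward_tf_style(g, x, y))   # (3+1j, 1-2j)

The two results differ by exactly a conjugation; writing the formulas the JAX way and conjugating once at the end keeps them clean, at the cost of the lost fused conjugate-add described above.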
"Gradient formulas look ugly" seems like a clear reason to prefer the JAX style. Posting this for other opinions.
cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen @anjali411 @dylanbespalko @vincentqb