Follow the JAX or Tensorflow convention for meaning of grad() with complex inputs #41857

@ezyang

Description

cc #755

JAX and TensorFlow disagree about whether the grad of a complex tensor should be conjugated. Here is an easy way to see the difference:

from jax import grad

def f(z):
  return z * z

z = 0.1j

print(grad(f, holomorphic=True)(z))

gives 0.2j, i.e. the holomorphic derivative f'(z) = 2z evaluated at z = 0.1j.

However, TensorFlow (here via the 1.x graph API) gives the conjugate:

>>> import tensorflow as tf
>>> sess = tf.Session()
>>> x = tf.Variable(0. + 0.j)
>>> sess.run(tf.gradients(x*x, x), feed_dict={x: 0.1j})
[-0.20000000000000001j]

source: tensorflow/tensorflow#3348
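
To make the two conventions explicit (my own summary, not taken from either issue): for a holomorphic function, JAX's grad returns the complex derivative itself, while TF returns its conjugate:

f(z)  = z * z
df/dz = 2z
JAX:  grad(f)(z) = df/dz       = 2z       ->  0.2j at z = 0.1j
TF:   grad(f)(z) = conj(df/dz) = conj(2z) -> -0.2j at z = 0.1j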

PyTorch also has to decide which side of the fence it will come down on. Right now, master implements the JAX convention. From reading the issue, here is my understanding of the pros and cons:

In favor of TF:

  • The gradient points in the correct direction for doing gradient descent. This means you can use a "stock" optimizer (one that was written with only real parameters in mind) without any changes to the optimizer. In contrast, to do gradient descent with the JAX definition, you have to remember to conjugate first (see the sketch after this bullet).
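
Here is a minimal numpy sketch of that optimizer-update difference, using a hand-computed derivative rather than either framework's actual autograd (the loss, target, and learning rate are made up for illustration). For a real-valued loss L(z) = |z - t|^2 of a complex parameter z, the Wirtinger derivative is dL/dz = conj(z - t), and the steepest-descent direction is -(z - t), i.e. -conj(dL/dz):

import numpy as np

t = 0.3 + 0.7j    # target (arbitrary)
lr = 0.1          # learning rate (arbitrary)

def dL_dz(z):
    # Wirtinger derivative of L(z) = |z - t|**2 with respect to z
    return np.conj(z - t)

# TF-style convention: the stored grad is conj(dL/dz) = z - t,
# so the stock update z -= lr * grad descends correctly.
z = 1.0 - 1.0j
for _ in range(200):
    z = z - lr * np.conj(dL_dz(z))
print(z)  # ~ (0.3+0.7j)

# JAX-style convention: the stored grad is dL/dz itself, so a stock
# optimizer that forgets to conjugate moves the wrong way in Im(z):
z = 1.0 - 1.0j
for _ in range(200):
    z = z - lr * dL_dz(z)
print(z)  # real part converges, imaginary part blows up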

In favor of JAX:

  • When the TF definition is implemented directly in the gradient formulas (as opposed to doing a single post-facto conjugation), you end up with ugly gradient formulas. For example, take a look at https://github.com/tensorflow/tensorflow/blob/70fd0a4436e3b49139653dc5b85d1c7df23f403d/tensorflow/python/ops/math_grad.py#L453 where TF has to explicitly conjugate the inputs. With the JAX-style definition, you can mostly reuse your real gradient formulas (see the sketch after this list).
  • The TF definition is less efficient, since you end up doing extra conjugations inside the gradient formulas. Even if you write the derivative formulas the JAX way and conjugate once before storing the result in x.grad, you lose the opportunity to do a fused conjugate-add in the optimizer update.
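
To see concretely how the JAX convention lets the real formulas carry over, here is a sketch of the two versions of a hand-written product-rule VJP (my own illustration, not either framework's actual internals):

import numpy as np

def mul_vjp_jax_style(grad_out, a, b):
    # identical to the real-valued product rule: no conjugations
    return grad_out * b, grad_out * a

def mul_vjp_tf_style(grad_out, a, b):
    # conjugations baked into the formula, as in the math_grad.py link above
    return grad_out * np.conj(b), grad_out * np.conj(a)

# For f(z) = z * z, summing the two partials reproduces the outputs above:
print(sum(mul_vjp_jax_style(1.0, 0.1j, 0.1j)))  # 0.2j
print(sum(mul_vjp_tf_style(1.0, 0.1j, 0.1j)))   # -0.2j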

"Gradient formulas look ugly" seems like a clear reason to prefer the JAX style. Posting this for other opinions.

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen @anjali411 @dylanbespalko @vincentqb

Labels

high priority, module: autograd, module: complex, triaged
