
Conversation

@vishwakftw (Contributor)

This closes #8940.


Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  Tensor zero_mask = (exponent == 0.0);
  return at::where(zero_mask, zeros_like(self), grad * exponent * self.pow(exponent - 1));
}
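
(Editor's note: a quick Python-level illustration of what this masked backward yields for a zero entry in a tensor exponent; the values are illustrative and assume a build that includes this change.)

import torch

x = torch.tensor([0., 2., 3.], requires_grad=True)
p = torch.tensor([0., 1., 2.])   # exponent tensor with one zero entry
x.pow(p).sum().backward()
print(x.grad)                    # tensor([0., 1., 6.]): the zero-exponent entry gets 0 instead of the nan from 0 * 0**-1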

@ssnl (Collaborator) previously requested changes Jun 27, 2018

need test in test_autograd.py

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vadimkantorov (Contributor)

Does an at::where overload exist that can accept a float tensor instead of a binary mask tensor? If it exists, then the mask allocation could also be eliminated (if needed).

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vishwakftw (Contributor, Author)

@pytorchbot retest this please

@ezyang (Contributor)

ezyang commented Jun 28, 2018

@vadimkantorov Unfortunately not. This might be a good addition (and pretty easy to add), if there are other cases where we might need it. Maybe we should generalize this into some sort of arbitrary equality test against a floating point number? (CC @colesbury @cpuhrsch for more opinions).
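
(Editor's note, for context: torch.where, like at::where, takes a boolean/byte condition tensor, so the comparison result is materialized as a separate mask. A minimal illustrative sketch, not from this PR:)

import torch

exponent = torch.tensor([0., 1., 2.])
zero_mask = exponent == 0.0      # a separate mask tensor, as in the C++ snippet above
out = torch.where(zero_mask, torch.zeros_like(exponent), torch.ones_like(exponent))
print(out)                       # tensor([0., 1., 1.])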

I do have a question beyond what this patch does: when the exponent is close to zero (but not exactly zero), what happens to the gradients? Is it numerically stable? If it's not, it would be nice (though not strictly necessary) to fix that too.

@vishwakftw (Contributor, Author)

vishwakftw commented Jun 28, 2018

@ezyang As the exponent tends to 0, the derivative tends to 0 for non-zero x and to inf at x = 0. This behaves correctly in PyTorch:

>>> a
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], requires_grad=True)
>>> a.pow(0.0001).sum().backward()
>>> a.grad
tensor([   inf, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0000, 0.0000,
        0.0000])
>>> a.pow(0.00001).sum().backward()
>>> a.grad
tensor([   inf, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0000, 0.0000,
        0.0000])
>>> a.pow(0.01).sum().backward()
>>> a.grad
tensor([   inf, 0.0103, 0.0052, 0.0035, 0.0026, 0.0021, 0.0017, 0.0015, 0.0013,
        0.0012])
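
(Editor's note: a quick analytic cross-check, not from the thread. Since d/dx x^p = p * x^(p-1), the gradient shrinks toward 0 for x > 0 as p tends to 0, while at x = 0 with 0 < p < 1 the factor x^(p-1) is infinite. A minimal sketch with illustrative values; torch.autograd.grad is used here so gradients are not accumulated into a.grad across calls.)

import torch

x = torch.tensor([0., 1., 2.], requires_grad=True)
for p in (1e-2, 1e-4):
    g, = torch.autograd.grad(x.pow(p).sum(), x)
    print(p, g)   # first entry is inf (x = 0); the rest equal p * x**(p - 1) and shrink as p shrinks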

@vadimkantorov (Contributor)

vadimkantorov commented Jun 28, 2018

@ezyang Thinking more about it, merging a comparison op and torch.where is essentially an op-fusion question (maybe it's a good case for the automatic fuser?). Comparison followed by torch.where is a frequent use case, I guess, but surfacing all comparison ops through where may add unjustified API complexity.

The zero special case would work here only because of a lucky coincidence: the exponent happens to be compared against zero.

@ezyang (Contributor)

ezyang commented Jun 29, 2018

@vadimkantorov Certainly, "where" would be easy to support in the JIT fuser. I'd also be OK with a special case just for the zero test. Up to you guys!

@vishwakftw (Contributor, Author)

I think that for the purposes of this PR, the usage of where can remain as currently designed. If the semantics of where change later, this part of the code can be revisited and modified accordingly. What do you think, @ezyang?

@ezyang (Contributor)

ezyang commented Jun 29, 2018

Works for me.

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vishwakftw vishwakftw deleted the pow-0-derivative branch June 29, 2018 13:55
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
This closes pytorch#8940 .
Closes pytorch#8945

Differential Revision: D8668853

Pulled By: ezyang

fbshipit-source-id: 80a629352ee2f506c38a05647b769281579a5af7

Successfully merging this pull request may close this issue: Issue in a particular case in backpropagation.
