Fix x.pow(0) gradient when x contains 0 #8945
Conversation
```cpp
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  Tensor zero_mask = (exponent == 0.0);
  return at::where(zero_mask, zeros_like(self), grad * exponent * self.pow(exponent - 1));
}
```
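For context, a minimal Python sketch of what this C++ change guards against (the tensors below are illustrative, not from the PR's test suite): with a zero exponent, the unmasked formula `grad * exponent * self.pow(exponent - 1)` evaluates `0 * x**(-1)`, which is `0 * inf = nan` at `x = 0`, even though the derivative of `x**0` is 0 everywhere. The `at::where` mask returns an exact zero instead.

```python
import torch

# Illustrative Python mirror of the backward logic above (not the ATen code itself).
x = torch.tensor([0.0, 1.0, 2.0])
grad = torch.ones_like(x)          # incoming gradient
exponent = torch.zeros_like(x)     # the problematic case: exponent == 0

naive = grad * exponent * x.pow(exponent - 1)            # tensor([nan, 0., 0.])
fixed = torch.where(exponent == 0.0,
                    torch.zeros_like(x),
                    grad * exponent * x.pow(exponent - 1))  # tensor([0., 0., 0.])
print(naive)
print(fixed)
```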
ssnl left a comment:
need test in test_autograd.py
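A rough sketch of the kind of check such a test could add (the test name and tensors here are illustrative, not the test that was actually committed; it exercises the tensor-exponent overload handled by pow_backward_self above):

```python
import torch

def test_pow_zero_exponent_grad():
    # Illustrative only: the gradient of x ** 0 should be 0 everywhere,
    # including at x = 0, rather than the nan from 0 * x**(-1).
    x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    exponent = torch.zeros_like(x)
    x.pow(exponent).sum().backward()
    assert torch.equal(x.grad, torch.zeros_like(x))

test_pow_zero_exponent_grad()
```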
facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Does there exist a […]
facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
@vadimkantorov Unfortunately not. This might be a good addition (and pretty easy to add), if there are other cases where we might need it. Maybe we should generalize this into some sort of arbitrary equality test against a floating point number? (CC @colesbury @cpuhrsch for more opinions.) I do have a question beyond what this patch does: when the exponent is close to zero (but not exactly zero), what happens to the gradients? Is it numerically stable? If it's not, it would be nice (though not strictly necessary) to fix that too.
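The generalization floated here, an equality test against an arbitrary floating-point value used as a mask, might look roughly like the sketch below. The helper name `where_eq` is hypothetical; nothing like it is claimed to exist in ATen or torch.

```python
import torch

def where_eq(tensor, value, if_equal, otherwise):
    # Hypothetical generalization of the zero-exponent special case:
    # pick `if_equal` wherever `tensor` equals an arbitrary float `value`.
    return torch.where(tensor == value, if_equal, otherwise)

# The patch above is then the special case value == 0.0:
#   where_eq(exponent, 0.0, torch.zeros_like(self), grad * exponent * self.pow(exponent - 1))
```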
@ezyang As the exponent tends to 0, the derivative is 0 for non-zero inputs:

```
>>> a
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], requires_grad=True)
>>> a.pow(0.0001).sum().backward()
>>> a.grad
tensor([ inf, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0000, 0.0000,
        0.0000])
>>> a.pow(0.00001).sum().backward()
>>> a.grad
tensor([ inf, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001, 0.0000, 0.0000,
        0.0000])
>>> a.pow(0.01).sum().backward()
>>> a.grad
tensor([ inf, 0.0103, 0.0052, 0.0035, 0.0026, 0.0021, 0.0017, 0.0015, 0.0013,
        0.0012])
```
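Note that `.grad` accumulates across repeated `backward()` calls, so the printed values above include contributions from earlier calls. A small sketch (exponents chosen for illustration) that isolates the gradient per exponent shows the same qualitative picture: for x > 0 the analytic derivative p * x**(p - 1) shrinks toward 0 as p -> 0, while at x = 0 it is inf for any small positive p, so the inf is the true value of the derivative rather than a numerical artifact.

```python
import torch

x = torch.arange(10.0, requires_grad=True)
for p in (0.01, 0.0001, 0.00001):
    # torch.autograd.grad avoids accumulating into x.grad between exponents.
    (g,) = torch.autograd.grad(x.pow(p).sum(), x)
    print(p, g)
# d/dx x**p = p * x**(p - 1): inf at x = 0 for 0 < p < 1, and -> 0 elsewhere as p -> 0.
```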
@ezyang Thinking more about it, merging a comparison binary op and […]. The zero special case here would work because of a lucky coincidence: […]
@vadimkantorov Certainly, "where" would be easy to support in the JIT fuser. I'd also be OK with a special case just for the zero test. Up to you guys!
I think for the purposes of this PR, the usage of at::where is fine.
Works for me.
facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot left a comment:
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This closes pytorch#8940. Closes pytorch#8945.
Differential Revision: D8668853
Pulled By: ezyang
fbshipit-source-id: 80a629352ee2f506c38a05647b769281579a5af7
This closes #8940.