
Commit d65b3bc

vishwakftw authored and Rob Kunkle committed
Fix x.pow(0) gradient when x contains 0 (pytorch#8945)
Summary: This closes pytorch#8940. Closes pytorch#8945.

Differential Revision: D8668853
Pulled By: ezyang
fbshipit-source-id: 80a629352ee2f506c38a05647b769281579a5af7
1 parent 4e454a3 commit d65b3bc
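
For context, a minimal sketch of the behavior this commit addresses (the shape and values below are illustrative, not taken from the commit): with the old backward formula grad * exponent * self.pow(exponent - 1), a zero exponent at a zero input evaluates 0 * 0**-1 = 0 * inf = nan, whereas x**0 is constant and its gradient should be exactly zero.

import torch

# Illustrative input: zeros, since x == 0 is where the old formula broke.
x = torch.zeros(3, requires_grad=True)

# x ** 0 is constant (all ones), so d/dx should be exactly zero.
x.pow(0).sum().backward()

# After this fix: tensor([0., 0., 0.]); before it, the backward formula
# grad * 0 * x.pow(-1) produced 0 * inf = nan at x == 0.
print(x.grad)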

File tree

3 files changed: +30 −3 lines changed

test/test_autograd.py

Lines changed: 10 additions & 0 deletions

@@ -2030,6 +2030,16 @@ def run_test(input_size, norm_deg):
         run_test((10,), 1)
         run_test((10,), 1.5)

+    def test_pow_zero_tensor_gradient(self):
+        def run_test(input_size, exponent):
+            input = torch.zeros(*input_size, requires_grad=True)
+            input.pow(exponent).sum().backward()
+            self.assertEqual(input.grad.data.abs().sum(), 0)
+
+        run_test((10,), torch.zeros(10))
+        run_test((10, 10), torch.zeros(10, 10))
+        run_test((10,), 0)
+
     def test_profiler(self):
         x = torch.randn(10, 10)

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 3 deletions

@@ -529,11 +529,11 @@
   input2: not_implemented("potri")

 - name: pow(Tensor self, Scalar exponent)
-  self: grad * exponent * self.pow(exponent.toDouble() - 1)
+  self: pow_backward(grad, self, exponent)

 - name: pow(Tensor self, Tensor exponent)
-  self: grad * exponent * self.pow(exponent - 1)
-  exponent: grad * self.pow(exponent) * self.log()
+  self: pow_backward_self(grad, self, exponent)
+  exponent: pow_backward_exponent(grad, self, exponent)

 - name: _prod(Tensor self, int64_t dim, bool keepdim)
   self: prod_backward(grad, self, result, dim, keepdim)
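
The yaml entries above now only name the backward helpers; away from zero the underlying formulas (power rule for self, x**n * log(x) for the exponent) are unchanged. A rough numeric check of those formulas, with illustrative values chosen here rather than taken from the commit:

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
n = torch.tensor([3.0, 2.0], requires_grad=True)

x.pow(n).sum().backward()

# Power rule for the base: d/dx x**n = n * x**(n - 1) -> [12., 6.]
print(x.grad)

# Log rule for the exponent: d/dn x**n = x**n * log(x)
# -> [8 * ln(2), 9 * ln(3)] ≈ [5.545, 9.888]
print(n.grad)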

tools/autograd/templates/Functions.cpp

Lines changed: 17 additions & 0 deletions

@@ -108,6 +108,23 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor
   return norm_backward(grad, self, p_, norm);
 }

+Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
+  double exponent = exponent_.toDouble();
+  if (exponent == 0.0) {
+    return zeros_like(self);
+  } else {
+    return grad * exponent * self.pow(exponent - 1);
+  }
+}
+
+Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
+  return at::where(exponent == 0.0, at::zeros({}, grad.type()), grad * exponent * self.pow(exponent - 1));
+}
+
+Tensor pow_backward_exponent(Tensor grad, const Tensor & self, const Tensor & exponent) {
+  return grad * self.pow(exponent) * self.log();
+}
+
 Tensor reduce_to(const Tensor & grad, IntList sizes) {
   if (sizes.size() == 0) {
     return grad.sum();