Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,16 @@ def run_test(input_size, norm_deg):
run_test((10,), 1)
run_test((10,), 1.5)

def test_pow_zero_tensor_gradient(self):
def run_test(input_size, exponent):
input = torch.zeros(*input_size, requires_grad=True)
input.pow(exponent).sum().backward()
self.assertEqual(input.grad.data.abs().sum(), 0)

run_test((10,), torch.zeros(10))
run_test((10, 10), torch.zeros(10, 10))
run_test((10,), 0)

def test_profiler(self):
x = torch.randn(10, 10)

Expand Down
6 changes: 3 additions & 3 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,11 @@
input2: not_implemented("potri")

- name: pow(Tensor self, Scalar exponent)
self: grad * exponent * self.pow(exponent.toDouble() - 1)
self: pow_backward(grad, self, exponent)

- name: pow(Tensor self, Tensor exponent)
self: grad * exponent * self.pow(exponent - 1)
exponent: grad * self.pow(exponent) * self.log()
self: pow_backward_self(grad, self, exponent)
exponent: pow_backward_exponent(grad, self, exponent)

- name: _prod(Tensor self, int64_t dim, bool keepdim)
self: prod_backward(grad, self, result, dim, keepdim)
Expand Down
17 changes: 17 additions & 0 deletions tools/autograd/templates/Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor
return norm_backward(grad, self, p_, norm);
}

Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
double exponent = exponent_.toDouble();
if (exponent == 0.0) {
return zeros_like(self);
} else {
return grad * exponent * self.pow(exponent - 1);
}
}

Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
return at::where(exponent == 0.0, at::zeros({}, grad.type()), grad * exponent * self.pow(exponent - 1));
}

Tensor pow_backward_exponent(Tensor grad, const Tensor & self, const Tensor & exponent) {
return grad * self.pow(exponent) * self.log();
}

Tensor reduce_to(const Tensor & grad, IntList sizes) {
if (sizes.size() == 0) {
return grad.sum();
Expand Down