Commit 31aad4f

add tests, copy diagonal code to backward for double differentiability

1 parent 7bf93af

2 files changed: +44, -2 lines

test/test_torch.py

Lines changed: 15 additions & 0 deletions
@@ -1928,6 +1928,21 @@ def test_diagonal_multidim(self):
         expected = xn.diagonal(0, -2, -1)
         self.assertEqual(expected.shape, result.shape)
         self.assertTrue(np.allclose(expected, result.numpy()))
+        # test non-contiguous
+        xp = x.permute(1, 2, 3, 0)
+        result = torch.diagonal(xp, 0, -2, -1)
+        expected = xp.numpy().diagonal(0, -2, -1)
+        self.assertEqual(expected.shape, result.shape)
+        self.assertTrue(np.allclose(expected, result.numpy()))
+        # test that the backward requires grad;
+        # we do this because diagonal_backward uses in-place
+        # operations and gradgradcheck does not catch whether
+        # they work as expected
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.diagonal(a)**2
+        c = b.sum()
+        d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
+        self.assertTrue(d.requires_grad)

     @staticmethod
     def _test_diagflat(self, dtype, device):
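A note on the explicit requires_grad assertion added above: numeric checks such as torch.autograd.gradcheck / gradgradcheck verify derivative values but, as the comment says, would not flag a backward that drops the autograd graph through an in-place copy. As a minimal sketch of how a numeric check would typically be run alongside this test (double precision assumed for numeric stability; not part of the commit):

import torch
from torch.autograd import gradcheck, gradgradcheck

# Sketch only: run the standard numeric first/second derivative checks on
# torch.diagonal. These complement, but do not replace, the requires_grad
# assertion in the test above.
a = torch.randn(5, 6, dtype=torch.double, requires_grad=True)
fn = lambda t: torch.diagonal(t, 0, -2, -1)
assert gradcheck(fn, (a,))
assert gradgradcheck(fn, (a,))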

tools/autograd/templates/Functions.cpp

Lines changed: 29 additions & 2 deletions
@@ -717,9 +717,36 @@ Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal)
   return grad_input;
 }

-Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1_, int64_t dim2_) {
   auto grad_input = at::zeros(grad.type(), input_sizes);
-  auto diag = at::diagonal(grad_input, offset, dim1, dim2);
+  // The following, up to the assignment of auto diag,
+  // copies the diagonal code from
+  // aten/src/ATen/native/TensorShape.cpp and is equivalent to
+  //   auto diag = grad_input.diagonal(offset, dim1, dim2);
+  // When diagonal is used directly, the output is not twice
+  // differentiable, while this version is.
+  int64_t nDims = input_sizes.size();
+  int64_t dim1 = at::maybe_wrap_dim(dim1_, nDims);
+  int64_t dim2 = at::maybe_wrap_dim(dim2_, nDims);
+  int64_t diag_size;
+  int64_t storage_offset = grad_input.storage_offset();
+  if (offset >= 0) {
+    diag_size = std::min(grad_input.size(dim1), grad_input.size(dim2) - offset);
+    storage_offset += offset * grad_input.stride(dim2);
+  } else {
+    diag_size = std::min(grad_input.size(dim1) + offset, grad_input.size(dim2));
+    storage_offset -= offset * grad_input.stride(dim1);
+  }
+  auto sizes = std::vector<int64_t>(grad_input.sizes());
+  auto strides = std::vector<int64_t>(grad_input.strides());
+  sizes.erase(sizes.begin() + std::max(dim1, dim2));
+  strides.erase(strides.begin() + std::max(dim1, dim2));
+  sizes.erase(sizes.begin() + std::min(dim1, dim2));
+  strides.erase(strides.begin() + std::min(dim1, dim2));
+  sizes.push_back(diag_size);
+  strides.push_back(grad_input.stride(dim1) + grad_input.stride(dim2));
+  auto diag = grad_input.as_strided(sizes, strides, storage_offset);
+
   diag.copy_(grad);
   return grad_input;
 }
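For readers less familiar with the C++ side, the strided-view construction above can be mirrored in Python. This is an illustrative sketch only: it assumes dim1/dim2 are already non-negative (the C++ code wraps them with at::maybe_wrap_dim), and the helper name diagonal_backward_sketch is made up for this note, not part of the commit.

import torch

def diagonal_backward_sketch(grad, input_sizes, offset, dim1, dim2):
    # Python mirror of the C++ above: build the same strided view of
    # grad_input that grad_input.diagonal(offset, dim1, dim2) would give,
    # then copy grad into it.
    grad_input = torch.zeros(input_sizes, dtype=grad.dtype)
    sizes = list(grad_input.size())
    strides = list(grad_input.stride())
    storage_offset = grad_input.storage_offset()
    if offset >= 0:
        diag_size = min(sizes[dim1], sizes[dim2] - offset)
        storage_offset += offset * strides[dim2]
    else:
        diag_size = min(sizes[dim1] + offset, sizes[dim2])
        storage_offset -= offset * strides[dim1]
    # drop dim1 and dim2, append the diagonal as the last dimension
    for d in sorted((dim1, dim2), reverse=True):
        del sizes[d]
        del strides[d]
    sizes.append(diag_size)
    strides.append(grad_input.stride(dim1) + grad_input.stride(dim2))
    diag = grad_input.as_strided(sizes, strides, storage_offset)
    diag.copy_(grad)
    return grad_input

# e.g. for a 2-D input this scatters grad onto the main diagonal of a
# zero tensor of the original shape:
g = torch.randn(5)
gi = diagonal_backward_sketch(g, (5, 6), 0, 0, 1)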
