
Commit a7d0b1f

improve tests and doc comment. Thank you, Adam!
1 parent 31aad4f commit a7d0b1f

File tree

3 files changed: +22 -26 lines changed

test/test_autograd.py

Lines changed: 12 additions & 0 deletions
@@ -2116,6 +2116,18 @@ def test_mul_out_result_requires_grad(self):
         # we should throw an exception if the output requires grad
         self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))
 
+    def test_diagonal_derivative_requires_grad(self):
+        # test that the backward requires grad
+        # we do this because diagonal_backward uses inplace
+        # operations and gradgradcheck does not catch whether
+        # they work as expected (it will succeed even if
+        # the gradient has requires_grad == False)
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.diagonal(a)**2
+        c = b.sum()
+        d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
+        self.assertTrue(d.requires_grad)
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
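The new test exists because gradgradcheck can pass even when a backward built from in-place operations returns a gradient that has been detached from the graph. Below is a minimal standalone sketch (not part of the commit) of the same check, extended one step to show that a gradient produced with create_graph=True can itself be differentiated:

import torch

a = torch.randn(5, 6, requires_grad=True)
c = (torch.diagonal(a) ** 2).sum()

# create_graph=True asks autograd to record operations on the gradient itself,
# so d should stay attached to the graph (requires_grad == True).
d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
print(d.requires_grad)  # expected: True

# Because d is attached, a second-order gradient can be taken through it.
dd, = torch.autograd.grad(d.sum(), a)
print(dd.shape)  # torch.Size([5, 6])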

test/test_torch.py

Lines changed: 8 additions & 24 deletions
@@ -1913,36 +1913,20 @@ def test_diagonal(self):
     def test_diagonal_multidim(self):
         x = torch.randn(10, 11, 12, 13)
         xn = x.numpy()
-        result = torch.diagonal(x, 2, 2, 3)
-        expected = xn.diagonal(2, 2, 3)
-        self.assertEqual(expected.shape, result.shape)
-        self.assertTrue(np.allclose(expected, result.numpy()))
-        result = torch.diagonal(x, 2)
-        expected = torch.diagonal(x, 2, 0, 1)
-        self.assertEqual(expected, result)
-        result = torch.diagonal(x, -2, 1, 2)
-        expected = xn.diagonal(-2, 1, 2)
-        self.assertEqual(expected.shape, result.shape)
-        self.assertTrue(np.allclose(expected, result.numpy()))
-        result = torch.diagonal(x, 0, -2, -1)
-        expected = xn.diagonal(0, -2, -1)
-        self.assertEqual(expected.shape, result.shape)
-        self.assertTrue(np.allclose(expected, result.numpy()))
+        for args in [(2, 2, 3),
+                     (2,),
+                     (-2, 1, 2),
+                     (0, -2, -1)]:
+            result = torch.diagonal(x, *args)
+            expected = xn.diagonal(*args)
+            self.assertEqual(expected.shape, result.shape)
+            self.assertTrue(np.allclose(expected, result.numpy()))
         # test non-contiguous
         xp = x.permute(1, 2, 3, 0)
         result = torch.diagonal(xp, 0, -2, -1)
         expected = xp.numpy().diagonal(0, -2, -1)
         self.assertEqual(expected.shape, result.shape)
         self.assertTrue(np.allclose(expected, result.numpy()))
-        # test that the backward requires grad
-        # we do this is because diagonal_backward uses inplace
-        # operations and gradgradcheck does not catch whether
-        # they works as expected
-        a = torch.randn(5, 6, requires_grad=True)
-        b = torch.diagonal(a)**2
-        c = b.sum()
-        d, = torch.autograd.grad(c,a, retain_graph=True, create_graph=True)
-        self.assertTrue(d.requires_grad)
 
     @staticmethod
     def _test_diagflat(self, dtype, device):
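The refactor replaces four copy-pasted comparison blocks with a table-driven loop over (offset, dim1, dim2) argument tuples. If each combination should be reported separately on failure, the same pattern can be wrapped in unittest's subTest; the following is an illustrative sketch only, not part of the commit:

import unittest

import numpy as np
import torch


class TestDiagonalMultidim(unittest.TestCase):
    def test_diagonal_multidim(self):
        x = torch.randn(10, 11, 12, 13)
        xn = x.numpy()
        for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]:
            # subTest names the failing (offset, dim1, dim2) combination in
            # the output instead of stopping at the first mismatch.
            with self.subTest(args=args):
                result = torch.diagonal(x, *args)
                expected = xn.diagonal(*args)
                self.assertEqual(expected.shape, result.shape)
                self.assertTrue(np.allclose(expected, result.numpy()))


if __name__ == "__main__":
    unittest.main()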

torch/_torch_docs.py

Lines changed: 2 additions & 2 deletions
@@ -1313,8 +1313,8 @@
 -0.2239
 [torch.FloatTensor of size 2]
 
->>> x = torch.randn(2,5,4,2)
->>> torch.diagonal(x, -1, 1, 2)
+>>> x = torch.randn(2, 5, 4, 2)
+>>> torch.diagonal(x, offset=-1, dim1=1, dim2=2)
 
 (0 ,.,.) =
 -0.6806 -0.0281 -0.6595 -0.4199
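The doc change spells out the keyword names (offset, dim1, dim2), which come straight from the torch.diagonal signature shown in the diff. A small usage check (not from the docstring) of what that call produces for the shape used in the example:

import torch

x = torch.randn(2, 5, 4, 2)
# dim1=1 and dim2=2 are removed and the diagonal is appended as the last
# dimension; with offset=-1 its length is min(5 - 1, 4) == 4.
d = torch.diagonal(x, offset=-1, dim1=1, dim2=2)
print(d.shape)  # torch.Size([2, 2, 4])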
