@@ -1913,36 +1913,20 @@ def test_diagonal(self):
19131913 def test_diagonal_multidim (self ):
19141914 x = torch .randn (10 , 11 , 12 , 13 )
19151915 xn = x .numpy ()
1916- result = torch .diagonal (x , 2 , 2 , 3 )
1917- expected = xn .diagonal (2 , 2 , 3 )
1918- self .assertEqual (expected .shape , result .shape )
1919- self .assertTrue (np .allclose (expected , result .numpy ()))
1920- result = torch .diagonal (x , 2 )
1921- expected = torch .diagonal (x , 2 , 0 , 1 )
1922- self .assertEqual (expected , result )
1923- result = torch .diagonal (x , - 2 , 1 , 2 )
1924- expected = xn .diagonal (- 2 , 1 , 2 )
1925- self .assertEqual (expected .shape , result .shape )
1926- self .assertTrue (np .allclose (expected , result .numpy ()))
1927- result = torch .diagonal (x , 0 , - 2 , - 1 )
1928- expected = xn .diagonal (0 , - 2 , - 1 )
1929- self .assertEqual (expected .shape , result .shape )
1930- self .assertTrue (np .allclose (expected , result .numpy ()))
1916+ for args in [(2 , 2 , 3 ),
1917+ (2 ,),
1918+ (- 2 , 1 , 2 ),
1919+ (0 , - 2 , - 1 )]:
1920+ result = torch .diagonal (x , * args )
1921+ expected = xn .diagonal (* args )
1922+ self .assertEqual (expected .shape , result .shape )
1923+ self .assertTrue (np .allclose (expected , result .numpy ()))
19311924 # test non-continguous
19321925 xp = x .permute (1 , 2 , 3 , 0 )
19331926 result = torch .diagonal (xp , 0 , - 2 , - 1 )
19341927 expected = xp .numpy ().diagonal (0 , - 2 , - 1 )
19351928 self .assertEqual (expected .shape , result .shape )
19361929 self .assertTrue (np .allclose (expected , result .numpy ()))
1937- # test that the backward requires grad
1938- # we do this is because diagonal_backward uses inplace
1939- # operations and gradgradcheck does not catch whether
1940- # they works as expected
1941- a = torch .randn (5 , 6 , requires_grad = True )
1942- b = torch .diagonal (a )** 2
1943- c = b .sum ()
1944- d , = torch .autograd .grad (c ,a , retain_graph = True , create_graph = True )
1945- self .assertTrue (d .requires_grad )
19461930
19471931 @staticmethod
19481932 def _test_diagflat (self , dtype , device ):
0 commit comments