@@ -6720,6 +6720,8 @@ def _test_flip(self, use_cuda=False):
67206720 self .assertEqual (torch .tensor ([7 , 8 , 5 , 6 , 3 , 4 , 1 , 2 ]).view (2 , 2 , 2 ), data .flip (0 , 1 ))
67216721 self .assertEqual (torch .tensor ([8 , 7 , 6 , 5 , 4 , 3 , 2 , 1 ]).view (2 , 2 , 2 ), data .flip (0 , 1 , 2 ))
67226722
6723+ # check for wrap dim
6724+ self .assertEqual (torch .tensor ([2 , 1 , 4 , 3 , 6 , 5 , 8 , 7 ]).view (2 , 2 , 2 ), data .flip (- 1 ))
67236725 # check for permute
67246726 self .assertEqual (torch .tensor ([6 , 5 , 8 , 7 , 2 , 1 , 4 , 3 ]).view (2 , 2 , 2 ), data .flip (0 , 2 ))
67256727 self .assertEqual (torch .tensor ([6 , 5 , 8 , 7 , 2 , 1 , 4 , 3 ]).view (2 , 2 , 2 ), data .flip (2 , 0 ))
@@ -6730,8 +6732,6 @@ def _test_flip(self, use_cuda=False):
67306732 self .assertRaises (TypeError , lambda : data .flip ())
67316733 # not allow size of flip dim > total dims
67326734 self .assertRaises (RuntimeError , lambda : data .flip (0 , 1 , 2 , 3 ))
6733- # not allow dim < 0
6734- self .assertRaises (RuntimeError , lambda : data .flip (- 1 ))
67356735 # not allow dim > max dim
67366736 self .assertRaises (RuntimeError , lambda : data .flip (3 ))
67376737
@@ -6756,6 +6756,10 @@ def _test_flip(self, use_cuda=False):
67566756 self .assertEqual (flip0_result , data .flip (0 ))
67576757 self .assertEqual (flip1_result , data .flip (1 ))
67586758
6759+ # test empty tensor, should just return an empty tensor of the same shape
6760+ data = torch .tensor ([])
6761+ self .assertEqual (data , data .flip (0 ))
6762+
67596763 def test_flip (self ):
67606764 self ._test_flip (self , use_cuda = False )
67616765
@@ -6769,6 +6773,44 @@ def test_reversed(self):
67696773 val = torch .tensor (42 )
67706774 self .assertEqual (reversed (val ), torch .tensor (42 ))
67716775
6776+ @staticmethod
6777+ def _test_rot90 (self , use_cuda = False ):
6778+ device = torch .device ("cuda" if use_cuda else "cpu" )
6779+ data = torch .arange (1 , 5 , device = device ).view (2 , 2 )
6780+ self .assertEqual (torch .tensor ([1 , 2 , 3 , 4 ]).view (2 , 2 ), data .rot90 (0 , [0 , 1 ]))
6781+ self .assertEqual (torch .tensor ([2 , 4 , 1 , 3 ]).view (2 , 2 ), data .rot90 (1 , [0 , 1 ]))
6782+ self .assertEqual (torch .tensor ([4 , 3 , 2 , 1 ]).view (2 , 2 ), data .rot90 (2 , [0 , 1 ]))
6783+ self .assertEqual (torch .tensor ([3 , 1 , 4 , 2 ]).view (2 , 2 ), data .rot90 (3 , [0 , 1 ]))
6784+
6785+ # test for default args k=1, dims=[0, 1]
6786+ self .assertEqual (data .rot90 (), data .rot90 (1 , [0 , 1 ]))
6787+
6788+ # test for reversed order of dims
6789+ self .assertEqual (data .rot90 (3 , [0 , 1 ]), data .rot90 (1 , [1 , 0 ]))
6790+
6791+ # test for modulo of k
6792+ self .assertEqual (data .rot90 (5 , [0 , 1 ]), data .rot90 (1 , [0 , 1 ]))
6793+ self .assertEqual (data .rot90 (3 , [0 , 1 ]), data .rot90 (- 1 , [0 , 1 ]))
6794+ self .assertEqual (data .rot90 (- 5 , [0 , 1 ]), data .rot90 (- 1 , [0 , 1 ]))
6795+
6796+ # test for dims out-of-range error
6797+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [0 , - 3 ]))
6798+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [0 , 2 ]))
6799+
6800+ # test tensor with more than 2D
6801+ data = torch .arange (1 , 9 , device = device ).view (2 , 2 , 2 )
6802+ self .assertEqual (torch .tensor ([2 , 4 , 1 , 3 , 6 , 8 , 5 , 7 ]).view (2 , 2 , 2 ), data .rot90 (1 , [1 , 2 ]))
6803+ self .assertEqual (data .rot90 (1 , [1 , - 1 ]), data .rot90 (1 , [1 , 2 ]))
6804+
6805+ # test for errors
6806+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [0 , 3 ]))
6807+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [1 , 1 ]))
6808+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [0 , 1 , 2 ]))
6809+ self .assertRaises (RuntimeError , lambda : data .rot90 (1 , [0 ]))
6810+
6811+ def test_rot90 (self ):
6812+ self ._test_rot90 (self , use_cuda = False )
6813+
67726814 def test_storage (self ):
67736815 v = torch .randn (3 , 5 )
67746816 self .assertEqual (v .storage ()[0 ], v .data [0 ][0 ])
0 commit comments