@@ -3828,10 +3828,11 @@ def test_var_stability(self):
38283828 self .assertEqual (tensor .var (0 )[0 ], 0.03125 )
38293829 self .assertEqual (tensor .var (), 0.03125 )
38303830
3831- def test_view (self ):
3832- tensor = torch .rand (15 )
3833- template = torch .rand (3 , 5 )
3834- empty = torch .Tensor ()
3831+ @staticmethod
3832+ def _test_view (self , cast ):
3833+ tensor = cast (torch .rand (15 ))
3834+ template = cast (torch .rand (3 , 5 ))
3835+ empty = cast (torch .Tensor ())
38353836 target = template .size ()
38363837 self .assertEqual (tensor .view_as (template ).size (), target )
38373838 self .assertEqual (tensor .view (3 , 5 ).size (), target )
@@ -3848,6 +3849,52 @@ def test_view(self):
38483849 self .assertRaises (RuntimeError , lambda : tensor .view (15 , 0 ))
38493850 self .assertRaises (RuntimeError , lambda : tensor .view (7 , - 1 ))
38503851 self .assertRaises (RuntimeError , lambda : tensor .view (15 , - 1 , - 1 ))
3852+ # test view when tensor is not contiguous in every dimension, but only
3853+ # contiguous dimensions are touched.
3854+ tensor = cast (torch .rand (4 , 2 , 5 , 1 , 6 , 2 , 9 , 3 )).transpose (- 1 , 2 ).transpose (- 2 , 3 )
3855+ # size: [ 4, 2, 3, 9, 6, 2, 1, 5]
3856+ # stride: [3840, 1620, 1, 3, 54, 27, 324, 324]
3857+ # contiguous dim chunks: [__________, ____, ____, __________, ____, ____]
3858+ # merging 1 to chunk after: [__________, ____, ____, __________, __________]
3859+ contig_tensor = tensor .clone ()
3860+ # [4, 2] => [8, 1]
3861+ # [3] => [3]
3862+ # [9] => [3, 3]
3863+ # [6, 2] => [4, 1, 3]
3864+ # [1, 5] => [5]
3865+ view_size = [8 , 1 , 3 , 3 , 3 , 4 , 1 , 3 , 5 ]
3866+ self .assertEqual (tensor .view (* view_size ), contig_tensor .view (* view_size ))
3867+ # [4, 2] => [2, 4]
3868+ # [3] => [3]
3869+ # [9] => [1, 9]
3870+ # [6, 2] => [2, 2, 3]
3871+ # [1, 5] => [5, 1]
3872+ view_size = [2 , 4 , 3 , 1 , 9 , 2 , 2 , 3 , 5 , 1 ]
3873+ self .assertEqual (tensor .view (* view_size ), contig_tensor .view (* view_size ))
3874+ # adding size 1 dims
3875+ view_size = [1 , 1 , 2 , 1 , 4 , 3 , 1 , 1 , 9 , 1 , 2 , 1 , 2 , 3 , 1 , 5 , 1 , 1 ]
3876+ self .assertEqual (tensor .view (* view_size ), contig_tensor .view (* view_size ))
3877+
3878+ # invalid views
3879+ self .assertRaises (RuntimeError , lambda : tensor .view (- 1 ))
3880+ # crossing [4, 2], [3]
3881+ self .assertRaises (RuntimeError , lambda : tensor .view (24 , 9 , 6 , 2 , 1 , 5 ))
3882+ # crossing [6, 2], [1, 5]
3883+ self .assertRaises (RuntimeError , lambda : tensor .view (8 , 3 , 9 , 6 , 10 ))
3884+ # crossing [9], [6, 2]
3885+ self .assertRaises (RuntimeError , lambda : tensor .view (8 , 3 , 54 , 2 , 1 , 5 ))
3886+
3887+ # view with stride 0 dims
3888+ tensor = cast (torch .Tensor (1 , 1 )).expand (3 , 4 ) # all dims are contiguous
3889+ contig_tensor = tensor .clone ()
3890+ self .assertEqual (tensor .view (- 1 ), contig_tensor .view (- 1 ))
3891+ self .assertEqual (tensor .view (1 , - 1 , 1 ), contig_tensor .view (1 , - 1 , 1 ))
3892+ self .assertEqual (tensor .view (- 1 , 1 ), contig_tensor .view (- 1 , 1 ))
3893+ self .assertEqual (tensor .view (6 , 2 , 1 ), contig_tensor .view (6 , 2 , 1 ))
3894+ self .assertEqual (tensor .view (1 , 6 , 2 , 1 ), contig_tensor .view (1 , 6 , 2 , 1 ))
3895+
3896+ def test_view (self ):
3897+ TestTorch ._test_view (self , lambda x : x )
38513898
38523899 def test_expand (self ):
38533900 tensor = torch .rand (1 , 8 , 1 )
0 commit comments