@@ -3748,19 +3748,48 @@ def test_expand(self):
37483748 self .assertEqual (torch .randn (()).expand (()), torch .randn (()))
37493749
37503750 def test_repeat (self ):
3751- result = torch .Tensor ()
3752- tensor = torch .rand (8 , 4 )
3751+
3752+ initial_shape = (8 , 4 )
3753+ tensor = torch .rand (* initial_shape )
3754+
37533755 size = (3 , 1 , 1 )
37543756 torchSize = torch .Size (size )
37553757 target = [3 , 8 , 4 ]
37563758 self .assertEqual (tensor .repeat (* size ).size (), target , 'Error in repeat' )
3757- self .assertEqual (tensor .repeat (torchSize ).size (), target , 'Error in repeat using LongStorage' )
3759+ self .assertEqual (tensor .repeat (torchSize ).size (), target ,
3760+ 'Error in repeat using LongStorage' )
37583761 result = tensor .repeat (* size )
37593762 self .assertEqual (result .size (), target , 'Error in repeat using result' )
37603763 result = tensor .repeat (torchSize )
37613764 self .assertEqual (result .size (), target , 'Error in repeat using result and LongStorage' )
37623765 self .assertEqual (result .mean (0 ).view (8 , 4 ), tensor , 'Error in repeat (not equal)' )
37633766
3767+ @unittest .skipIf (not TEST_NUMPY , "Numpy not found" )
3768+ def test_repeat_tile (self ):
3769+
3770+ initial_shape = (8 , 4 )
3771+
3772+ repeats = ((3 , 1 , 1 ),
3773+ (3 , 3 , 3 ),
3774+ (1 , 2 , 1 ),
3775+ (2 , 2 , 2 , 2 ))
3776+
3777+ def _generate_noncontiguous_input ():
3778+
3779+ out = np .broadcast_to (np .random .random ((1 , 4 )),
3780+ initial_shape )
3781+
3782+ assert not (out .flags .c_contiguous or out .flags .f_contiguous )
3783+
3784+ return out
3785+
3786+ for repeat in repeats :
3787+ for tensor in (torch .from_numpy (np .random .random (initial_shape )),
3788+ torch .from_numpy (_generate_noncontiguous_input ()),):
3789+
3790+ self .assertEqual (tensor .repeat (* repeat ).numpy (),
3791+ np .tile (tensor .numpy (), repeat ))
3792+
37643793 def test_is_same_size (self ):
37653794 t1 = torch .Tensor (3 , 4 , 9 , 10 )
37663795 t2 = torch .Tensor (3 , 4 )
0 commit comments