@@ -998,7 +998,7 @@ def test_remainder(self):
998998 long_m1 = torch .LongTensor (10 , 10 ).random_ (- 10 , 10 )
999999 long_res1 = long_m1 .clone ()
10001000 long_res2 = long_m1 .clone ()
1001- long_qs = torch .arange (- 5 , 5 ). long ()
1001+ long_qs = torch .arange (- 5 , 5 )
10021002 long_qs [5 ] = 5 # Can't handle the divisor=0 case
10031003 for col_idx , long_q in enumerate (long_qs ):
10041004 # Reference
@@ -2313,6 +2313,39 @@ def test_arange(self):
23132313 self .assertEqual (r1 , r2 , 0 )
23142314 self .assertEqual (r2 , r3 [:- 1 ], 0 )
23152315
2316+ def test_arange_inference (self ):
2317+ saved_dtype = torch .get_default_dtype ()
2318+ torch .set_default_dtype (torch .float32 )
2319+ # end only
2320+ self .assertIs (torch .float32 , torch .arange (1. ).dtype )
2321+ self .assertIs (torch .float32 , torch .arange (torch .tensor (1. )).dtype )
2322+ self .assertIs (torch .float32 , torch .arange (torch .tensor (1. , dtype = torch .float64 )).dtype )
2323+
2324+ self .assertIs (torch .int64 , torch .arange (1 ).dtype )
2325+ self .assertIs (torch .int64 , torch .arange (torch .tensor (1 )).dtype )
2326+ self .assertIs (torch .int64 , torch .arange (torch .tensor (1 , dtype = torch .int16 )).dtype )
2327+
2328+ # start, end, [step]
2329+ self .assertIs (torch .float32 , torch .arange (1. , 3 ).dtype )
2330+ self .assertIs (torch .float32 , torch .arange (torch .tensor (1. , dtype = torch .float64 ), 3 ).dtype )
2331+ self .assertIs (torch .float32 , torch .arange (1 , 3. ).dtype )
2332+ self .assertIs (torch .float32 , torch .arange (torch .tensor (1 , dtype = torch .int16 ), torch .tensor (3. )).dtype )
2333+ self .assertIs (torch .float32 , torch .arange (1 , 3 , 1. ).dtype )
2334+ self .assertIs (torch .float32 ,
2335+ torch .arange (torch .tensor (1 ),
2336+ torch .tensor (3 , dtype = torch .int16 ),
2337+ torch .tensor (1. , dtype = torch .float64 )).dtype )
2338+
2339+ self .assertIs (torch .int64 , torch .arange (1 , 3 ).dtype )
2340+ self .assertIs (torch .int64 , torch .arange (torch .tensor (1 ), 3 ).dtype )
2341+ self .assertIs (torch .int64 , torch .arange (torch .tensor (1 ), torch .tensor (3 , dtype = torch .int16 )).dtype )
2342+ self .assertIs (torch .int64 , torch .arange (1 , 3 , 1 ).dtype )
2343+ self .assertIs (torch .int64 ,
2344+ torch .arange (torch .tensor (1 ),
2345+ torch .tensor (3 ),
2346+ torch .tensor (1 , dtype = torch .int16 )).dtype )
2347+ torch .set_default_dtype (saved_dtype )
2348+
23162349 @staticmethod
23172350 def _select_broadcastable_dims (dims_full = None ):
23182351 # select full dimensionality
@@ -2883,7 +2916,7 @@ def test_median(self):
28832916 self .assertEqual (x , x0 , 0 )
28842917
28852918 def test_mode (self ):
2886- x = torch .arange (1 , SIZE * SIZE + 1 ).clone ().resize_ (SIZE , SIZE )
2919+ x = torch .arange (1. , SIZE * SIZE + 1 ).clone ().resize_ (SIZE , SIZE )
28872920 x [:2 ] = 1
28882921 x [:, :2 ] = 1
28892922 x0 = x .clone ()
@@ -3119,7 +3152,7 @@ def test_randn(self):
31193152
31203153 def test_slice (self ):
31213154 empty = torch .Tensor ()
3122- x = torch .arange (0 , 16 ).view (4 , 4 )
3155+ x = torch .arange (0. , 16 ).view (4 , 4 )
31233156 self .assertEqual (x .slice (), x )
31243157 self .assertEqual (x .slice (0 , 0 , 4 ), x )
31253158 # start and stop are clamped to the size of dim
@@ -3914,7 +3947,7 @@ def naive_stft(x, frame_length, hop, fft_size=None, normalized=False,
39143947 return_size = fft_size
39153948 result = x .new (batch , int ((length - frame_length ) / float (hop )) + 1 , return_size , 2 )
39163949 for w in range (return_size ): # freq
3917- radians = torch .arange (frame_length ) * w * 2 * math .pi / fft_size
3950+ radians = torch .arange (float ( frame_length ) ) * w * 2 * math .pi / fft_size
39183951 radians = radians .type_as (x )
39193952 re_kernel = radians .cos ().mul_ (window )
39203953 im_kernel = - radians .sin ().mul_ (window )
@@ -4576,7 +4609,7 @@ def ri(indices):
45764609 # strided is [[1 3 5 7],
45774610 # [9 11 13 15]]
45784611
4579- reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
4612+ reference = conv_fn (torch .arange (0. , 24 ).view (3 , 8 ))
45804613 strided = conv_fn (torch .Tensor ())
45814614 strided .set_ (reference .storage (), 1 , size = torch .Size ([2 , 4 ]),
45824615 stride = [8 , 2 ])
@@ -4614,15 +4647,15 @@ def ri(indices):
46144647 # strided is [[10, 11],
46154648 # [17, 18]]
46164649
4617- reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
4650+ reference = conv_fn (torch .arange (0. , 24 ).view (3 , 8 ))
46184651 strided = conv_fn (torch .Tensor ())
46194652 strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
46204653 stride = [7 , 1 ])
46214654 self .assertEqual (strided [ri ([0 ]), ri ([1 ])], torch .Tensor ([11 ]))
46224655 strided [ri ([0 ]), ri ([1 ])] = - 1
46234656 self .assertEqual (strided [ri ([0 ]), ri ([1 ])], torch .Tensor ([- 1 ]))
46244657
4625- reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
4658+ reference = conv_fn (torch .arange (0. , 24 ).view (3 , 8 ))
46264659 strided = conv_fn (torch .Tensor ())
46274660 strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
46284661 stride = [7 , 1 ])
@@ -4632,7 +4665,7 @@ def ri(indices):
46324665 self .assertEqual (strided [ri ([0 , 1 ]), ri ([1 , 0 ])], torch .Tensor ([- 1 ,
46334666 2 ]))
46344667
4635- reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
4668+ reference = conv_fn (torch .arange (0. , 24 ).view (3 , 8 ))
46364669 strided = conv_fn (torch .Tensor ())
46374670 strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
46384671 stride = [7 , 1 ])
@@ -4727,7 +4760,7 @@ def get_set_tensor(indexed, indexer):
47274760 # 5 6 7 8 9
47284761 # 10 11 12 13 14
47294762 # 15 16 17 18 19
4730- reference = conv_fn (torch .arange (0 , 20 ).view (4 , 5 ))
4763+ reference = conv_fn (torch .arange (0. , 20 ).view (4 , 5 ))
47314764
47324765 indices_to_test = [
47334766 # grab the second, fourth columns
@@ -4753,7 +4786,7 @@ def get_set_tensor(indexed, indexer):
47534786 indexer ,
47544787 get_set_tensor (reference , indexer ))
47554788
4756- reference = conv_fn (torch .arange (0 , 160 ).view (4 , 8 , 5 ))
4789+ reference = conv_fn (torch .arange (0. , 160 ).view (4 , 8 , 5 ))
47574790
47584791 indices_to_test = [
47594792 [slice (None ), slice (None ), [0 , 3 , 4 ]],
@@ -4804,7 +4837,7 @@ def get_set_tensor(indexed, indexer):
48044837 indexer ,
48054838 get_set_tensor (reference , indexer ))
48064839
4807- reference = conv_fn (torch .arange (0 , 1296 ).view (3 , 9 , 8 , 6 ))
4840+ reference = conv_fn (torch .arange (0. , 1296 ).view (3 , 9 , 8 , 6 ))
48084841
48094842 indices_to_test = [
48104843 [slice (None ), slice (None ), slice (None ), [0 , 3 , 4 ]],
0 commit comments