@@ -2790,8 +2790,8 @@ def test_randperm_cuda(self):
         self.assertEqual(res1, res2, 0)
 
         with torch.random.fork_rng(devices=[0]):
-            res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
-        res2 = torch.cuda.HalfTensor()
+            res1 = torch.randperm(50000, dtype=torch.float, device=cuda)
+        res2 = torch.cuda.FloatTensor()
         torch.randperm(50000, out=res2, device=cuda)
         self.assertEqual(res1, res2, 0)
 
@@ -2802,6 +2802,14 @@ def test_randperm_cuda(self):
         self.assertEqual(res1.numel(), 0)
         self.assertEqual(res2.numel(), 0)
 
+        # Test exceptions when n is too large for a floating point type
+        for res, small_n, large_n in ((torch.cuda.HalfTensor(), 2 ** 11 + 1, 2 ** 11 + 2),
+                                      (torch.cuda.FloatTensor(), 2 ** 24 + 1, 2 ** 24 + 2),
+                                      (torch.cuda.DoubleTensor(), 2 ** 25,  # 2**53 + 1 is too large to run
+                                       2 ** 53 + 2)):
+            torch.randperm(small_n, out=res)  # No exception expected
+            self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res))
+
     def test_random_neg_values(self):
         _TestTorchMixin._test_random_neg_values(self, use_cuda=True)
 
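The boundary values in the added test follow from the largest run of consecutive integers each floating point type can represent exactly: 2**11 for half, 2**24 for float, and 2**53 for double. Because randperm(n) must hold the values 0 through n-1, n equal to the limit plus one is still safe, while the limit plus two forces an unrepresentable value and should raise. A minimal sketch of the underlying precision boundary, using only plain tensor construction (illustrative only, not part of the change itself):

```python
import torch

# Half precision represents every integer up to 2**11 = 2048 exactly; beyond that,
# representable values are at least 2 apart, so 2049 rounds back down to 2048.
exact = torch.tensor(2 ** 11, dtype=torch.half).item()        # 2048.0
rounded = torch.tensor(2 ** 11 + 1, dtype=torch.half).item()  # also 2048.0
assert exact == rounded == 2048.0  # 2049 is not representable in half
```

The same reasoning explains why the existing 50000-element case is switched from half to float: 50000 is far above 2048, so a half-precision output tensor cannot hold the permutation values exactly.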