@@ -1903,8 +1903,8 @@ def get_int64_dtype(dtype):
             check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
                         dtype, layout, device, None, rg)
             check_value(v.new_empty(shape), dtype, layout, device, None, False)
-            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=rg),
-                        int64_dtype, layout, device, None, rg)
+            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
             check_value(torch.empty_like(v), dtype, layout, device, None, False)
             check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                         int64_dtype, layout, device, None, False)
@@ -1917,8 +1917,8 @@ def get_int64_dtype(dtype):
             out = v.new()
             check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
                         dtype, layout, device, fv + 2, rg)
-            check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=rg),
-                        int64_dtype, layout, device, fv + 3, rg)
+            check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, fv + 3, False)
             check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
             check_value(torch.full_like(v, fv + 5,
                                         dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
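
(Aside, not part of the diff: a minimal sketch of the factory behavior these check_value calls exercise. Tensor.new_full and Tensor.new_empty take dtype, device, and requires_grad overrides while inheriting unspecified options from the source tensor; the tests above now pass requires_grad=False for the int64 variants, presumably because only floating-point tensors can require gradients. The names below are illustrative, not from the commit.)

import torch

v = torch.ones(2, 3)                          # source tensor: float32 on CPU
t = v.new_full((2, 3), 7, dtype=torch.int64,  # dtype override, device inherited from v
               requires_grad=False)           # integer tensors cannot require grad

assert t.dtype == torch.int64
assert t.device == v.device
assert bool((t == 7).all())                   # filled with the given value
assert not t.requires_grad

e = v.new_empty((2, 3), dtype=torch.int64)    # same override rules, uninitialized data
assert e.dtype == torch.int64 and e.shape == (2, 3)
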
@@ -2697,12 +2697,12 @@ def test_contiguous(self):
 
     def test_scalars_as_floats(self):
         "zero-dim variables that don't require grad should bind to scalar arguments"
-        x = torch.tensor(2)
-        y = torch.tensor(3)
+        x = torch.tensor(2.)
+        y = torch.tensor(3.)
         # 3 + (3 * 3) * 2
         self.assertEqual(y.addcmul(y, y, value=x), 21)
 
-        x = torch.tensor(2, requires_grad=True)
+        x = torch.tensor(2., requires_grad=True)
         self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
 
     @staticmethod
@@ -6123,8 +6123,6 @@ def test_parsing_int64(self):
         self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0)))
         # doesn't accept floating point variables
         self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.)))
-        # doesn't accept variables with requires_grad
-        self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0, requires_grad=True)))
 
     def test_parsing_double(self):
         # accepts floating point and integer arguments
@@ -6136,8 +6134,6 @@ def test_parsing_double(self):
         self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all())
         self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all())
         # doesn't accept variables with requires_grad
-        self.assertRaises(TypeError,
-                          lambda: torch.isclose(x, x, torch.tensor(1, requires_grad=True), torch.tensor(1)).all())
         self.assertRaises(TypeError,
                           lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all())
 
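(Aside, not part of the diff: a minimal sketch of the scalar-binding rule the last three hunks test at this commit. A zero-dim floating-point tensor without requires_grad binds to a scalar argument, while a requires_grad tensor is rejected; later PyTorch releases may behave differently.)

import torch

x = torch.tensor(2.)          # zero-dim float, requires_grad=False
y = torch.tensor(3.)

# x binds as the scalar `value`: 3 + 2 * (3 * 3) == 21
assert y.addcmul(y, y, value=x).item() == 21

# At this commit, a tensor with requires_grad cannot bind as a scalar.
x_rg = torch.tensor(2., requires_grad=True)
try:
    y.addcmul(y, y, value=x_rg)
except Exception as exc:
    print("rejected as scalar:", exc)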