@@ -1898,6 +1898,30 @@ def _make_cov(S):
     return torch.mm(L, L.t())
 
 
+def random_square_matrix_of_rank(l, rank):
+    assert rank <= l
+    A = torch.randn(l, l)
+    u, s, v = A.svd()
+    for i in range(l):
+        if i >= rank:
+            s[i] = 0
+        elif s[i] == 0:
+            s[i] = 1
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
+def random_symmetric_matrix(l):
+    A = torch.randn(l, l)
+    return A.mm(A.transpose(0, 1))
+
+
+def random_fullrank_matrix_distinct_singular_value(l):
+    A = torch.randn(l, l)
+    u, _, v = A.svd()
+    s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
 class dont_convert(tuple):
     pass
 
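The three helpers added above generate the structured inputs for the new det tests: a square matrix of a prescribed rank (by zeroing trailing singular values), a symmetric matrix built as A times its transpose, and a full-rank matrix whose singular values are distinct and bounded away from zero. A minimal standalone sketch of how the rank guarantee can be spot-checked; the numerical_rank helper and the 1e-5 tolerance are illustrative assumptions, not part of this diff:

    import torch

    def numerical_rank(A, tol=1e-5):
        # Count singular values above a small tolerance.
        _, s, _ = A.svd()
        return int((s > tol).sum())

    A = random_square_matrix_of_rank(5, 3)                  # helper from the diff above
    B = random_fullrank_matrix_distinct_singular_value(5)   # helper from the diff above
    print(numerical_rank(A), numerical_rank(B))             # expected: 3 5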
@@ -1906,7 +1930,6 @@ class dont_convert(tuple):
 M = 10
 S = 5
 
-
 # (name, size, args...)
 method_tests = [
     ('add', (S, S, S), ((S, S, S),)),
@@ -2166,6 +2189,13 @@ class dont_convert(tuple):
     ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
     ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
     ('inverse', (S, S), (), '', (), [skipIfNoLapack]),
+    ('det', (S, S), (), '', (), [skipIfNoLapack]),
+    ('det', lambda: random_symmetric_matrix(S), (), 'symmetric', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
+    ('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
+    ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_postive_s', (), [skipIfNoLapack]),
+    ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
     ('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
     ('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
     ('eq', (S, S, S), ((S, S, S),)),
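The new method_tests entries cover det on a plain random square matrix plus variants that pin down the input structure: symmetric, rank S - 2 ('dim2_null'), rank 1, rank 2, and full rank with distinct singular values; svd also gets the well-conditioned full-rank input. The rank-deficient variants exercise det where its value is numerically zero, roughly as in this standalone sketch (uses today's torch.det on a plain tensor and the helper from the hunk above; not part of the test harness):

    import torch

    A = random_square_matrix_of_rank(5, 3)   # rank-deficient 5x5 input
    print(torch.det(A))                      # expected: approximately 0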
@@ -2303,6 +2333,8 @@ def maybe_non_contig(tensor):
             return Variable(maybe_non_contig(arg), requires_grad=requires_grad)
         elif isinstance(arg, Variable) and non_contiguous:
             return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad)
+        elif callable(arg):
+            return map_arg(arg())
         else:
             return arg
     return tuple(map_arg(arg) for arg in call_args)
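With the callable branch above, a method_tests input slot may now hold a zero-argument callable, such as the det lambdas; when the test inputs are constructed, the callable is invoked and its result is routed back through map_arg, so each run draws a fresh matrix with the required structure instead of reusing one fixed tensor. A minimal sketch of that dispatch idea, separate from the real create_input/map_arg harness:

    import torch

    def resolve(arg):
        # Callables are evaluated lazily, so every resolution yields a fresh input.
        if callable(arg):
            return resolve(arg())
        return arg

    spec = lambda: torch.randn(5, 5)
    a = resolve(spec)
    b = resolve(spec)
    print(torch.equal(a, b))   # expected: False -- two independent draws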
@@ -2339,6 +2371,19 @@ def maybe_non_contig(tensor):
 EXCLUDE_GRADCHECK = {
     'potrf'
 }
+EXCLUDE_GRADGRADCHECK = {
+    'svd'
+}
+EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
+    # Some of the following det tests pass only because a random matrix has
+    # full rank with high probability, and we can't rely on that, so gradgrad
+    # is only checked on test_det_distinct_postive_s.
+    'test_det',
+    'test_det_symmetric',
+    'test_det_dim2_null',
+    'test_det_rank1',
+    'test_det_rank2'
+}
 
 
 def exclude_tensor_method(name, test_name):
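The exclusion sets reflect how the det gradient behaves: by Jacobi's formula, the derivative of det at an invertible A is det(A) times the transposed inverse, so second-order checks are only trustworthy on well-conditioned full-rank inputs, and the rank-deficient det variants (and svd) skip gradgradcheck. A small standalone check of that formula; illustrative only, and written against the current tensor API (requires_grad on tensors) rather than this file's Variable wrapper:

    import torch

    A = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
    d = torch.det(A)
    grad, = torch.autograd.grad(d, A)
    expected = d.detach() * torch.inverse(A.detach()).t()
    print(torch.allclose(grad, expected))   # expected: True (a random A is almost surely invertible)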
@@ -2359,6 +2404,7 @@ def exclude_tensor_method(name, test_name):
         'resize_as',
         'scatter',
         'scatter_add',
+        'det',
     }
     if test_name in exclude_all_tensor_method_by_test_name:
         return True
@@ -2390,17 +2436,19 @@ def gradgradcheck_method_precision_override(test_name):
     return override
 
 
-def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable, input_variables):
+def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
+                                 input_variables, run_gradgradcheck=True):
     test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
-
+    if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
+        return
     grad_y = generate_gradoutput(output_variable, non_contiguous=True)
     gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
     if gradgradcheck_precision_override is not None:
         atol = gradgradcheck_precision_override['atol']
         rtol = gradgradcheck_precision_override['rtol']
         test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y, atol=atol, rtol=rtol))
     else:
-        test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y,))
+        test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y))
 
 
 def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
@@ -2413,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
     test_case.assertEqual(unpack_variables(output_variable), output_tensor)
 
     if run_grad_checks:
-        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
+        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
                                      output_variable, f_args_variable)
 
     self_variable = f_args_variable[0]
@@ -2457,7 +2505,7 @@ def check(name):
             # TODO: check that both have changed after adding all inplace ops
 
             if not is_inplace and name not in EXCLUDE_GRADCHECK:
-                run_grad_and_gradgrad_checks(self, test_name,
+                run_grad_and_gradgrad_checks(self, name, test_name,
                                              lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
                                              output_variable, (self_variable,) + args_variable)
 
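Taken together, a generated det test amounts to building the input (possibly via one of the callables), running the forward op, and asserting gradcheck, with gradgradcheck added only for tests outside the exclusion lists. A rough standalone equivalent of the core check, written against today's API rather than this file's Variable-based harness; the 1e-4 tolerance mirrors the harness's PRECISION only as an assumption:

    import torch
    from torch.autograd import gradcheck

    A = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
    print(gradcheck(torch.det, (A,), eps=1e-6, atol=1e-4))   # expected: True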