@@ -1915,6 +1915,13 @@ def random_symmetric_matrix(l):
     return A.mm(A.transpose(0, 1))


+def random_fullrank_matrix_distinct_singular_value(l):
+    A = torch.randn(l, l)
+    u, _, v = A.svd()
+    s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
 class dont_convert(tuple):
     pass

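Illustrative aside, not part of the patch: the helper added above keeps U and V from the SVD of a random matrix but replaces the singular values with the strictly increasing sequence 1/(l+1), 2/(l+1), ..., l/(l+1), so the result is full rank with pairwise distinct singular values. A minimal sanity check, assuming it runs in the same environment as the patched test file (where torch.arange returns a float tensor):

    import torch

    l = 5
    M = random_fullrank_matrix_distinct_singular_value(l)
    _, s, _ = M.svd()                                  # recover the singular values
    sorted_s, _ = s.sort()
    assert sorted_s[0] > 0                             # full rank: smallest singular value is positive
    assert (sorted_s[1:] - sorted_s[:-1]).min() > 0    # singular values are pairwise distinct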
@@ -2187,6 +2194,8 @@ class dont_convert(tuple):
     ('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
     ('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
     ('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
+    ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_positive_s', (), [skipIfNoLapack]),
+    ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
     ('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
     ('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
     ('eq', (S, S, S), ((S, S, S),)),
@@ -2363,7 +2372,17 @@ def maybe_non_contig(tensor):
     'potrf'
 }
 EXCLUDE_GRADGRADCHECK = {
-    'det'
+    'svd'
+}
+EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
+    # Some of the following det tests pass only because the random matrix has
+    # full rank with high probability. We can't rely on that, so gradgrad is
+    # only checked on test_det_distinct_positive_s.
+    'test_det',
+    'test_det_symmetric',
+    'test_det_dim2_null',
+    'test_det_rank1',
+    'test_det_rank2'
 }

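Editorial aside, not part of the patch: a short sketch of why only the full-rank input with distinct singular values is gradgrad-checked for det. By Jacobi's formula,

    d det(A) / dA = det(A) * A^{-T}

so the determinant gradient already assumes an invertible input, and the standard SVD gradient contains factors of the form 1 / (s_i^2 - s_j^2), which are only finite when the singular values are pairwise distinct. The random_fullrank_matrix_distinct_singular_value helper is constructed to produce exactly such matrices.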
@@ -2417,10 +2436,10 @@ def gradgradcheck_method_precision_override(test_name):
     return override


-def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable,
+def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
                                  input_variables, run_gradgradcheck=True):
     test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
-    if not run_gradgradcheck:
+    if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
         return
     grad_y = generate_gradoutput(output_variable, non_contiguous=True)
     gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
@@ -2442,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
     test_case.assertEqual(unpack_variables(output_variable), output_tensor)

     if run_grad_checks:
-        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
+        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
                                      output_variable, f_args_variable)

     self_variable = f_args_variable[0]
@@ -2486,10 +2505,9 @@ def check(name):
             # TODO: check that both have changed after adding all inplace ops

             if not is_inplace and name not in EXCLUDE_GRADCHECK:
-                run_grad_and_gradgrad_checks(self, test_name,
+                run_grad_and_gradgrad_checks(self, name, test_name,
                                              lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
-                                             output_variable, (self_variable,) + args_variable,
-                                             name not in EXCLUDE_GRADGRADCHECK)
+                                             output_variable, (self_variable,) + args_variable)

             # functional interface tests
             if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL: