@@ -637,6 +637,13 @@ def method_tests():
637637 ('det' , lambda : random_square_matrix_of_rank (S , 2 ), NO_ARGS , 'rank2' , (), NO_ARGS , [skipIfNoLapack ]),
638638 ('det' , lambda : random_fullrank_matrix_distinct_singular_value (S ), NO_ARGS ,
639639 'distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ]),
640+ ('det' , (3 , 3 , S , S ), NO_ARGS , 'batched' , (), NO_ARGS , [skipIfNoLapack ]),
641+ ('det' , (3 , 3 , 1 , 1 ), NO_ARGS , 'batched_1x1' , (), NO_ARGS , [skipIfNoLapack ]),
642+ ('det' , lambda : random_symmetric_matrix (S , 3 ), NO_ARGS , 'batched_symmetric' , (), NO_ARGS , [skipIfNoLapack ]),
643+ ('det' , lambda : random_symmetric_psd_matrix (S , 3 ), NO_ARGS , 'batched_symmetric_psd' , (), NO_ARGS , [skipIfNoLapack ]),
644+ ('det' , lambda : random_symmetric_pd_matrix (S , 3 ), NO_ARGS , 'batched_symmetric_pd' , (), NO_ARGS , [skipIfNoLapack ]),
645+ ('det' , lambda : random_fullrank_matrix_distinct_singular_value (S , 3 , 3 ), NO_ARGS ,
646+ 'batched_distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ]),
640647 # For `logdet` and `slogdet`, the function at det=0 is not smooth.
641648 # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use
642649 # `make_nonzero_det` to make the random matrices have nonzero det. For
@@ -650,6 +657,14 @@ def method_tests():
650657 'symmetric_pd' , (), NO_ARGS , [skipIfNoLapack ]),
651658 ('logdet' , lambda : make_nonzero_det (random_fullrank_matrix_distinct_singular_value (S ), 1 , 0 ), NO_ARGS ,
652659 'distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ]),
660+ ('logdet' , lambda : make_nonzero_det (torch .randn (3 , 3 , S , S ), 1 ), NO_ARGS , 'batched' , (), NO_ARGS , [skipIfNoLapack ]),
661+ ('logdet' , lambda : make_nonzero_det (torch .randn (3 , 3 , 1 , 1 ), 1 ), NO_ARGS , 'batched_1x1' , (), NO_ARGS , [skipIfNoLapack ]),
662+ ('logdet' , lambda : make_nonzero_det (random_symmetric_matrix (S , 3 ), 1 ), NO_ARGS ,
663+ 'batched_symmetric' , (), NO_ARGS , [skipIfNoLapack ]),
664+ ('logdet' , lambda : make_nonzero_det (random_symmetric_pd_matrix (S , 3 ), 1 ), NO_ARGS ,
665+ 'batched_symmetric_pd' , (), NO_ARGS , [skipIfNoLapack ]),
666+ ('logdet' , lambda : make_nonzero_det (random_fullrank_matrix_distinct_singular_value (S , 3 ), 1 , 0 ), NO_ARGS ,
667+ 'batched_distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ]),
653668 ('slogdet' , lambda : make_nonzero_det (torch .randn (1 , 1 ), 1 ), NO_ARGS ,
654669 '1x1_pos_det' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
655670 ('slogdet' , lambda : make_nonzero_det (torch .randn (1 , 1 ), - 1 ), NO_ARGS ,
@@ -664,6 +679,16 @@ def method_tests():
664679 'symmetric_pd' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
665680 ('slogdet' , lambda : random_fullrank_matrix_distinct_singular_value (S ), NO_ARGS ,
666681 'distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
682+ ('slogdet' , lambda : make_nonzero_det (torch .randn (3 , 3 , 1 , 1 ), - 1 ), NO_ARGS ,
683+ 'batched_1x1_neg_det' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
684+ ('slogdet' , lambda : make_nonzero_det (torch .randn (3 , 3 , S , S ), 1 ), NO_ARGS ,
685+ 'batched_pos_det' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
686+ ('slogdet' , lambda : make_nonzero_det (random_symmetric_matrix (S , 3 )), NO_ARGS ,
687+ 'batched_symmetric' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
688+ ('slogdet' , lambda : random_symmetric_pd_matrix (S , 3 ), NO_ARGS ,
689+ 'batched_symmetric_pd' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
690+ ('slogdet' , lambda : random_fullrank_matrix_distinct_singular_value (S , 3 ), NO_ARGS ,
691+ 'batched_distinct_singular_values' , (), NO_ARGS , [skipIfNoLapack ], itemgetter (1 )),
667692 ('symeig' , lambda : random_symmetric_matrix (S ), (True , False ), 'lower' , (), NO_ARGS , [skipIfNoLapack ]),
668693 ('symeig' , lambda : random_symmetric_matrix (S ), (True , True ), 'upper' , (), NO_ARGS , [skipIfNoLapack ]),
669694 ('symeig' , lambda : random_symmetric_matrix (M ), (True , True ), 'large' , (), NO_ARGS , [skipIfNoLapack ]),
@@ -1082,14 +1107,23 @@ def unpack_variables(args):
10821107 'test_det_dim2_null' ,
10831108 'test_det_rank1' ,
10841109 'test_det_rank2' ,
1110+ 'test_det_batched' ,
1111+ 'test_det_batched_1x1' ,
1112+ 'test_det_batched_symmetric' ,
1113+ 'test_det_batched_symmetric_psd' ,
10851114 # `other` expand_as(self, other) is not used in autograd.
10861115 'test_expand_as' ,
10871116 'test_logdet' ,
10881117 'test_logdet_1x1' ,
10891118 'test_logdet_symmetric' ,
1119+ 'test_logdet_batched' ,
1120+ 'test_logdet_batched_1x1' ,
1121+ 'test_logdet_batched_symmetric' ,
10901122 'test_slogdet_1x1_neg_det' ,
10911123 'test_slogdet_neg_det' ,
10921124 'test_slogdet_symmetric' ,
1125+ 'test_slogdet_batched_1x1_neg_det' ,
1126+ 'test_slogdet_batched_symmetric' ,
10931127 'test_cdist' ,
10941128}
10951129
0 commit comments