@@ -19,11 +19,7 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
1919 p.squeeze_ (0 );
2020 lu.squeeze_ (0 );
2121 int int_info = info.squeeze_ ().toCInt ();
22- if (int_info < 0 ) {
23- std::ostringstream ss;
24- ss << " LU factorization (getrf) failed with info = " << int_info;
25- throw std::runtime_error (ss.str ());
26- }
22+ AT_CHECK (int_info >= 0 , " LU factorization (getrf) failed with info = " , int_info);
2723 auto n = self.size (0 );
2824 auto num_exchanges = (at::arange (1 , n + 1 , p.type ()) != p).nonzero ().size (0 );
2925 if (num_exchanges % 2 == 1 ) {
@@ -34,13 +30,10 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
3430}
3531
3632Tensor det (const Tensor& self) {
37- if (!at::isFloatingType (self.type ().scalarType ()) ||
38- self.dim () != 2 || self.size (0 ) != self.size (1 )) {
39- std::ostringstream ss;
40- ss << " det(" << self.type () << " {" << self.sizes () << " }): expected a 2D "
41- << " square tensor of floating types" ;
42- throw std::runtime_error (ss.str ());
43- }
33+ AT_CHECK (at::isFloatingType (self.type ().scalarType ()) &&
34+ self.dim () == 2 && self.size (0 ) == self.size (1 ),
35+ " det(" , self.type (), " {" , self.sizes (), " }): expected a 2D square tensor "
36+ " of floating types" );
4437 double det_P;
4538 Tensor diag_U;
4639 int info;
@@ -53,13 +46,10 @@ Tensor det(const Tensor& self) {
5346}
5447
5548Tensor logdet (const Tensor& self) {
56- if (!at::isFloatingType (self.type ().scalarType ()) ||
57- self.dim () != 2 || self.size (0 ) != self.size (1 )) {
58- std::ostringstream ss;
59- ss << " logdet(" << self.type () << " {" << self.sizes () << " }): expected a "
60- << " 2D square tensor of floating types" ;
61- throw std::runtime_error (ss.str ());
62- }
49+ AT_CHECK (at::isFloatingType (self.type ().scalarType ()) &&
50+ self.dim () == 2 && self.size (0 ) == self.size (1 ),
51+ " logdet(" , self.type (), " {" , self.sizes (), " }): expected a 2D square tensor "
52+ " of floating types" );
6353 double det_P;
6454 Tensor diag_U, det;
6555 int info;
@@ -77,13 +67,10 @@ Tensor logdet(const Tensor& self) {
7767}
7868
7969std::tuple<Tensor, Tensor> slogdet (const Tensor& self) {
80- if (!at::isFloatingType (self.type ().scalarType ()) ||
81- self.dim () != 2 || self.size (0 ) != self.size (1 )) {
82- std::ostringstream ss;
83- ss << " slogdet(" << self.type () << " {" << self.sizes () << " }): expected a "
84- << " 2D square tensor of floating types" ;
85- throw std::runtime_error (ss.str ());
86- }
70+ AT_CHECK (at::isFloatingType (self.type ().scalarType ()) &&
71+ self.dim () == 2 && self.size (0 ) == self.size (1 ),
72+ " slogdet(" , self.type (), " {" , self.sizes (), " }): expected a 2D square tensor "
73+ " of floating types" );
8774 double det_P;
8875 Tensor diag_U, det;
8976 int info;
@@ -96,10 +83,19 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
9683 return std::make_tuple (det.sign (), diag_U.abs_ ().log_ ().sum ());
9784}
9885
86+ Tensor pinverse (const Tensor& self, double rcond) {
87+ AT_CHECK (at::isFloatingType (self.type ().scalarType ()) && self.dim () == 2 ,
88+ " pinverse(" , self.type (), " {" , self.sizes (), " }): expected a 2D tensor "
89+ " of floating types" );
90+ Tensor U, S, V;
91+ std::tie (U, S, V) = self.svd ();
92+ double max_val = S[0 ].toCDouble ();
93+ Tensor S_pseudoinv = at::where (S > rcond * max_val, S.reciprocal (), at::zeros ({}, self.options ()));
94+ return V.mm (S_pseudoinv.diag ().mm (U.t ()));
95+ }
96+
9997static void check_1d (const Tensor& t, const char * arg, const char * fn) {
100- if (t.dim () != 1 ) {
101- AT_ERROR (fn, " : Expected 1-D argument " , arg, " , but got " , t.dim (), " -D" );
102- }
98+ AT_CHECK (t.dim () == 1 , fn, " : Expected 1-D argument " , arg, " , but got " , t.dim (), " -D" );
10399}
104100
105101Tensor ger (const Tensor& self, const Tensor& vec2) {
0 commit comments