#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
-#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/native/LinearAlgebraUtils.h"
#ifdef USE_LAPACK

// gesv
-extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
-extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
+extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
+extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);

// inverse
extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
+
+// potrs
+extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
+extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
#endif

namespace at {
@@ -32,12 +35,12 @@ namespace native {
// Define the per-batch functions to be used in the main implementation of the batched
// linear algebra operations
template<class scalar_t>
-void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
+void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
  AT_ERROR("gesv only takes float or double Tensors");
}

template<class scalar_t>
-void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
+void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
  AT_ERROR("getrf only takes float or double Tensors");
}

@@ -46,12 +49,17 @@ void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwo
  AT_ERROR("getri only takes float or double Tensors");
}

+template<class scalar_t>
+void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("potrs only takes float or double Tensors");
+}
+
#ifdef USE_LAPACK
-template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
+template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
  dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

-template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
+template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
  sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

@@ -70,6 +78,14 @@ template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv,
template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
  sgetrf_(&m, &n, a, &lda, ipiv, info);
}
+
+template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+  dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+  spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
#endif

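// Note on the routine being wrapped (not part of the change itself): LAPACK ?potrs
// solves A * X = B for X, where A is symmetric positive definite and is supplied as
// its Cholesky factor (A = U**T * U when uplo = 'U', A = L * L**T when uplo = 'L'),
// typically obtained from ?potrf. B is overwritten with the solution X, matrices are
// expected in column-major order, and a negative *info reports that the corresponding
// argument had an illegal value.
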
// Below are the definitions of the functions operating on a batch that are going to be dispatched
@@ -105,8 +121,16 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
  }
}

-// These utilities are specified in LinearAlgebraUtils.h
-GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cpu)
+std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  });
+  batchCheckErrors(infos, "gesv");
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}

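// The helper above follows the common pattern for these batched CPU routines: make
// column-major working copies (LAPACK expects Fortran layout; cloneBatchedColumnMajor
// presumably handles this per matrix in the batch), let AT_DISPATCH_FLOATING_TYPES
// instantiate the lambda with scalar_t bound to float or double, and have
// batchCheckErrors turn any nonzero LAPACK info codes into an error.
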
// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
@@ -117,21 +141,8 @@ std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
    return at::_th_gesv_single(self, A);
  }

-  gesvCheckInputs(self, A);
-
-  // broadcast the batch dimensions of self and A.
-  IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
-  IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
-  std::vector<int64_t> expand_batch_portion = infer_size(self_batch_sizes, A_batch_sizes);
-
-  std::vector<int64_t> self_expand_size({expand_batch_portion});
-  self_expand_size.insert(self_expand_size.end(), { self.size(-2), self.size(-1) });
-
-  std::vector<int64_t> A_expand_size({expand_batch_portion});
-  A_expand_size.insert(A_expand_size.end(), { A.size(-2), A.size(-1) });
-
-  Tensor self_broadcasted = self.expand(self_expand_size);
-  Tensor A_broadcasted = A.expand(A_expand_size);
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
  return at::_gesv_helper(self_broadcasted, A_broadcasted);
}

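// For illustration (based on the inline code removed above): the batch dimensions of
// self and A are broadcast against each other before dispatch, e.g.
//
//   self: (2, 1, 4, 6)  ->  batch dims (2, 1), matrix dims (4, 6)
//   A:    (   5, 4, 4)  ->  batch dims (5,),   matrix dims (4, 4)
//   infer_size({2, 1}, {5}) == {2, 5}
//   self_broadcasted: (2, 5, 4, 6),  A_broadcasted: (2, 5, 4, 4)
//
// _linear_solve_broadcast_args (in LinearAlgebraUtils.h) is assumed to centralize this
// same expansion for both gesv and potrs.
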
@@ -185,7 +196,15 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
  }
}

-GENERATE_LINALG_HELPER_1_ARGS(inverse, self, cpu)
+Tensor _inverse_helper_cpu(const Tensor& self) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
+    apply_inverse<scalar_t>(self_working_copy, infos);
+  });
+  batchCheckErrors(infos, "inverse");
+  return self_working_copy;
+}

Tensor inverse(const Tensor &self) {
  if (self.size(-1) == 0) {
@@ -206,4 +225,63 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) {
  return result;
}

+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<typename scalar_t>
+static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("potrs: LAPACK library not found in compilation");
+#endif
+  char uplo = upper ? 'U' : 'L';
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto A_mat_stride = matrixStride(A);
+  auto b_mat_stride = matrixStride(b);
+
+  auto batch_size = batchCount(A);
+  auto n = A.size(-2);
+  auto nrhs = b.size(-1);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    int info;
+    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+    lapackPotrs<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+    infos[i] = info;
+    if (info != 0) {
+      return;
+    }
+  }
+}
+
+Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
+    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, infos);
+  });
+  batchCheckErrors(infos, "potrs");
+  return self_working_copy;
+}
+
+// Supports arbitrary batch dimensions for self and A
+Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
+  if (self.dim() <= 2 && A.dim() <= 2) {
+    return at::_th_potrs_single(self, A, upper);
+  }
+
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
+  return at::_potrs_helper(self_broadcasted, A_broadcasted, upper);
+}
+
+Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
+  AT_CHECK(self.dim() == 2 && A.dim() == 2,
+           "torch.potrs() with the `out` keyword does not support batching. "
+           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+  return at::_th_potrs_single_out(result, self, A, upper);
+}
+
}} // namespace at::native
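For context, a minimal sketch of how the new batched potrs entry point could be exercised from ATen C++ once this lands. The factory calls (at::eye and at::randn with a dtype) and the shapes are illustrative assumptions, not part of the change; potrs expects A to already be a Cholesky factor, so the identity is used here as a trivially valid factor.

#include <ATen/ATen.h>

int main() {
  // Batch of 4 systems, each 3x3, with 2 right-hand sides per system.
  at::Tensor A = at::eye(3, at::kDouble).expand({4, 3, 3});  // identity as the upper Cholesky factor
  at::Tensor b = at::randn({4, 3, 2}, at::kDouble);

  // dim() > 2 on either argument takes the new native path:
  // broadcast the batch dims, then _potrs_helper_cpu -> apply_potrs -> LAPACK ?potrs.
  at::Tensor x = at::potrs(b, A, /*upper=*/true);  // equals b here, since A is the identity

  return 0;
}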