Skip to content

Commit bcb5fd8

Browse files
vishwakftwfacebook-github-bot
authored and committed
Port symeig to ATen and enable batching of inputs (#21858)
Summary: Changelog: - Port `symeig` from TH/THC to ATen - Enable batching of matrix inputs for `symeig` - Modify derivative computation based on batching - Update docs to reflect the change Pull Request resolved: #21858 Test Plan: - Added additional tests in `test_torch.py` (with a port to `test_cuda.py`) and `common_methods_invocations.py` to test if both the port and batching work. Differential Revision: D15981789 Pulled By: soumith fbshipit-source-id: ab9af8361f8608db42318aabc8421bd99a1ca7ae
1 parent 4ec6fbe commit bcb5fd8

17 files changed

+319
-239
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,33 +2323,6 @@
23232323
- THTensor* self
23242324
- THTensor* A
23252325
]]
2326-
[[
2327-
name: _th_symeig
2328-
cname: syev
2329-
types:
2330-
- Float
2331-
- Double
2332-
backends:
2333-
- CPU
2334-
- CUDA
2335-
variants:
2336-
- function
2337-
return: argument 0,1
2338-
arguments:
2339-
- arg: THTensor* res1
2340-
output: True
2341-
- arg: THTensor* res2
2342-
output: True
2343-
- THTensor* self
2344-
- arg: bool eigenvectors
2345-
if_true: V
2346-
if_false: N
2347-
default: N
2348-
- arg: bool upper
2349-
if_true: U
2350-
if_false: L
2351-
default: U
2352-
]]
23532326
[[
23542327
name: _th_eig
23552328
cname: geev

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ extern "C" void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *w
4646
// orgqr
4747
extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
4848
extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
49+
50+
// syev
51+
extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
52+
extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
4953
#endif
5054

5155
namespace at {
@@ -93,6 +97,11 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala
9397
AT_ERROR("orgqr only takes float or double Tensors");
9498
}
9599

100+
template<class scalar_t>
101+
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info) {
102+
AT_ERROR("symeig only takes float or double Tensors");
103+
}
104+
96105
#ifdef USE_LAPACK
97106
template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
98107
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
@@ -157,6 +166,14 @@ template<> void lapackOrgqr<double>(int m, int n, int k, double *a, int lda, dou
157166
template<> void lapackOrgqr<float>(int m, int n, int k, float *a, int lda, float *tau, float *work, int lwork, int *info) {
158167
sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
159168
}
169+
170+
// Double-precision specialization: forwards every argument to LAPACK's dsyev_.
// Scalars are taken by value here and passed by address, as the Fortran-style
// declaration of dsyev_ requires.
template <>
void lapackSymeig<double>(
    char jobz, char uplo, int n,
    double* a, int lda,
    double* w,
    double* work, int lwork,
    int* info) {
  dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}
173+
174+
// Single-precision specialization: forwards every argument to LAPACK's ssyev_.
// Scalars are taken by value here and passed by address, as the Fortran-style
// declaration of ssyev_ requires.
template <>
void lapackSymeig<float>(
    char jobz, char uplo, int n,
    float* a, int lda,
    float* w,
    float* work, int lwork,
    int* info) {
  ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}
160177
#endif
161178

162179
// Below are the definitions of the functions operating on a batch that are going to be dispatched
@@ -833,4 +850,87 @@ std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo
833850
return std::tuple<Tensor&, Tensor&>(Q, R);
834851
}
835852

853+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
854+
855+
template <typename scalar_t>
856+
static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool upper, std::vector<int64_t>& infos) {
857+
#ifndef USE_LAPACK
858+
AT_ERROR("symeig: LAPACK library not found in compilation");
859+
#else
860+
auto self_data = self.data<scalar_t>();
861+
auto eigvals_data = eigvals.data<scalar_t>();
862+
auto self_matrix_stride = matrixStride(self);
863+
auto eigvals_stride = eigvals.size(-1);
864+
auto batch_size = batchCount(self);
865+
auto n = self.size(-1);
866+
867+
char uplo = upper ? 'U' : 'L';
868+
char jobz = eigenvectors ? 'V' : 'N';
869+
870+
int info;
871+
// Run once, first to get the optimum work size.
872+
// Since we deal with batches of matrices with the same dimensions, doing this outside
873+
// the loop saves (batch_size - 1) workspace queries which would provide the same result
874+
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
875+
int lwork = -1;
876+
scalar_t wkopt;
877+
lapackSymeig<scalar_t>(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, &info);
878+
lwork = static_cast<int>(wkopt);
879+
Tensor work = at::empty({lwork}, self.options());
880+
881+
for (int64_t i = 0; i < batch_size; i++) {
882+
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
883+
scalar_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride];
884+
885+
// now compute the eigenvalues and the eigenvectors (optionally)
886+
lapackSymeig<scalar_t>(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, work.data<scalar_t>(), lwork, &info);
887+
infos[i] = info;
888+
if (info != 0) {
889+
return;
890+
}
891+
}
892+
#endif
893+
}
894+
895+
std::tuple<Tensor, Tensor> _symeig_helper_cpu(const Tensor& self, bool eigenvectors, bool upper) {
896+
std::vector<int64_t> infos(batchCount(self), 0);
897+
898+
auto self_sizes = self.sizes().vec();
899+
self_sizes.pop_back();
900+
auto eigvals = at::empty(self_sizes, self.options());
901+
902+
if (self.numel() == 0) {
903+
return std::tuple<Tensor, Tensor>(eigvals, at::empty_like(self));
904+
}
905+
906+
auto self_working_copy = cloneBatchedColumnMajor(self);
907+
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "symeig_cpu", [&]{
908+
apply_symeig<scalar_t>(self_working_copy, eigvals, eigenvectors, upper, infos);
909+
});
910+
911+
if (!eigenvectors) {
912+
self_working_copy.zero_();
913+
}
914+
if (self.dim() > 2) {
915+
batchCheckErrors(infos, "symeig_cpu");
916+
} else {
917+
singleCheckErrors(infos[0], "symeig_cpu");
918+
}
919+
return std::tuple<Tensor, Tensor>(eigvals, self_working_copy);
920+
}
921+
922+
std::tuple<Tensor, Tensor> symeig(const Tensor& self, bool eigenvectors, bool upper) {
923+
squareCheckInputs(self);
924+
return at::_symeig_helper(self, eigenvectors, upper);
925+
}
926+
927+
// Out-variant of symeig: computes the decomposition into temporaries through
// at::_symeig_helper, then resizes `vals` / `vecs` to match and copies the
// results into them. Returns references to the two output tensors.
std::tuple<Tensor&, Tensor&> symeig_out(Tensor& vals, Tensor& vecs, const Tensor& self, bool eigenvectors, bool upper) {
  squareCheckInputs(self);

  auto computed = at::_symeig_helper(self, eigenvectors, upper);
  const Tensor& computed_vals = std::get<0>(computed);
  const Tensor& computed_vecs = std::get<1>(computed);

  vals.resize_as_(computed_vals);
  vals.copy_(computed_vals);
  vecs.resize_as_(computed_vecs);
  vecs.copy_(computed_vecs);

  return std::tie(vals, vecs);
}
935+
836936
}} // namespace at::native

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/TensorUtils.h>
44
#include <limits>
55
#include <sstream>
6+
#include <cstring>
67

78
namespace at { namespace native {
89

@@ -110,7 +111,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
110111
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
111112
}
112113

113-
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu)
114+
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu, symeig)
114115
static inline void squareCheckInputs(const Tensor& self) {
115116
TORCH_CHECK(self.size(-1) == self.size(-2),
116117
"A must be batches of square matrices, "
@@ -145,7 +146,12 @@ static inline void batchCheckErrors(const Tensor& infos, const char* name) {
145146
if (info < 0) {
146147
AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value");
147148
} else if (info > 0) {
148-
AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U.");
149+
if (strstr(name, "symeig")) {
150+
AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info,
151+
" off-diagonal elements of an intermediate tridiagonal form did not converge to zero.")
152+
} else {
153+
AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U.");
154+
}
149155
}
150156
}
151157
}
@@ -158,7 +164,12 @@ static inline void singleCheckErrors(int64_t info, const char* name) {
158164
if (info < 0) {
159165
AT_ERROR(name, ": Argument ", -info, " has illegal value");
160166
} else if (info > 0) {
161-
AT_ERROR(name, ": U(", info, ",", info, ") is zero, singular U.");
167+
if (strstr(name, "symeig")) {
168+
AT_ERROR(name, ": the algorithm failed to converge; ", info,
169+
" off-diagonal elements of an intermediate tridiagonal form did not converge to zero.")
170+
} else {
171+
AT_ERROR(name, ": U(", info, ",", info, ") is zero, singular U.");
172+
}
162173
}
163174
}
164175

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ template<class scalar_t>
143143
void magmaOrgqr(
144144
magma_int_t m, magma_int_t n, magma_int_t k, scalar_t* dA,
145145
magma_int_t ldda, scalar_t* tau, scalar_t* dT, magma_int_t nb, magma_int_t* info) {
146-
AT_ERROR("orgqr only takes float or doule Tensors");
146+
AT_ERROR("orgqr only takes float or double Tensors");
147+
}
148+
149+
template<class scalar_t>
150+
void magmaSymeig(
151+
magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, scalar_t* dA, magma_int_t ldda,
152+
scalar_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork,
153+
magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) {
154+
AT_ERROR("symeig only takes float or double Tensors");
147155
}
148156

149157
template<>
@@ -405,6 +413,22 @@ void magmaOrgqr<float>(
405413
float* tau, float* dT, magma_int_t nb, magma_int_t* info) {
406414
magma_sorgqr_gpu(m, n, k, dA, ldda, tau, dT, nb, info);
407415
}
416+
417+
template<>
418+
void magmaSymeig<double>(
419+
magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, double* dA, magma_int_t ldda,
420+
double* w, double* wA, magma_int_t ldwa, double* work, magma_int_t lwork,
421+
magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) {
422+
magma_dsyevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, iwork, liwork, info);
423+
}
424+
425+
template<>
426+
void magmaSymeig<float>(
427+
magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, float* dA, magma_int_t ldda,
428+
float* w, float* wA, magma_int_t ldwa, float* work, magma_int_t lwork,
429+
magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) {
430+
magma_ssyevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, iwork, liwork, info);
431+
}
408432
#endif
409433

410434
#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
@@ -1123,6 +1147,93 @@ std::tuple<Tensor,Tensor> _qr_helper_cuda(const Tensor& self, bool some) {
11231147
r_working_copy.narrow_copy(-2, 0, n_columns_q).triu_());
11241148
}
11251149

1150+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1151+
1152+
template <typename scalar_t>
1153+
static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool upper, std::vector<int64_t>& infos) {
1154+
#ifndef USE_MAGMA
1155+
AT_ERROR("symeig: MAGMA library not found in "
1156+
"compilation. Please rebuild with MAGMA.");
1157+
#else
1158+
auto self_data = self.data<scalar_t>();
1159+
auto eigvals_data = eigvals.data<scalar_t>();
1160+
auto self_matrix_stride = matrixStride(self);
1161+
auto eigvals_stride = eigvals.size(-1);
1162+
int64_t batch_size = batchCount(self);
1163+
magma_int_t n = magma_int_cast(self.size(-1), "n");
1164+
1165+
magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
1166+
magma_vec_t jobz = eigenvectors ? MagmaVec : MagmaNoVec;
1167+
1168+
scalar_t* wA;
1169+
ALLOCATE_ARRAY(wA, scalar_t, n * n, self);
1170+
1171+
magma_int_t info;
1172+
// Run once, first to get the optimum work sizes.
1173+
// Since we deal with batches of matrices with the same dimensions, doing this outside
1174+
// the loop saves (batch_size - 1) workspace queries which would provide the same result
1175+
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
1176+
magma_int_t lwork = -1;
1177+
scalar_t wkopt;
1178+
magma_int_t liwork = -1;
1179+
magma_int_t iwkopt;
1180+
magmaSymeig<scalar_t>(jobz, uplo, n, self_data, n, eigvals_data, wA, n, &wkopt, lwork, &iwkopt, liwork, &info);
1181+
1182+
scalar_t* work;
1183+
magma_int_t* iwork;
1184+
lwork = magma_int_cast(wkopt, "work_size");
1185+
liwork = magma_int_cast(iwkopt, "iwork_size");
1186+
ALLOCATE_ARRAY(work, scalar_t, lwork, self);
1187+
ALLOCATE_ARRAY(iwork, magma_int_t, liwork, self);
1188+
1189+
for (int64_t i = 0; i < batch_size; i++) {
1190+
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
1191+
scalar_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride];
1192+
magmaSymeig<scalar_t>(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr,
1193+
wA, n, work, lwork, iwork, liwork, &info);
1194+
infos[i] = info;
1195+
if (info != 0) {
1196+
return;
1197+
}
1198+
}
1199+
#endif
1200+
}
1201+
1202+
std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvectors, bool upper) {
1203+
std::vector<int64_t> infos(batchCount(self), 0);
1204+
1205+
auto self_sizes = self.sizes().vec();
1206+
self_sizes.pop_back();
1207+
1208+
// We create temporary tensors on the CPU, because tensors on the GPU
1209+
// cause segfault when passed to magmaSymeig. The data is later
1210+
// moved to the appropriate device.
1211+
// In the case where self.numel() == 0, we just return an empty tensor of
1212+
// dimensions on the CUDA (to avoid the unnecessary "to(at::kCUDA)")
1213+
auto eigvals_working_copy = self.numel() == 0
1214+
? at::empty(self_sizes, self.options())
1215+
: at::empty(self_sizes, self.options().device(at::kCPU));
1216+
1217+
if (self.numel() == 0) {
1218+
return std::tuple<Tensor, Tensor>(eigvals_working_copy, at::empty_like(self));
1219+
}
1220+
1221+
auto self_working_copy = cloneBatchedColumnMajor(self);
1222+
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "symeig_cuda", [&]{
1223+
apply_symeig<scalar_t>(self_working_copy, eigvals_working_copy, eigenvectors, upper, infos);
1224+
});
1225+
1226+
if (!eigenvectors) {
1227+
self_working_copy.zero_();
1228+
}
1229+
if (self.dim() > 2) {
1230+
batchCheckErrors(infos, "symeig_cuda");
1231+
} else {
1232+
singleCheckErrors(infos[0], "symeig_cuda");
1233+
}
1234+
return std::tuple<Tensor, Tensor>(eigvals_working_copy.to(self.device()), self_working_copy);
1235+
}
1236+
11261237
}} // namespace at::native
11271238

11281239
#undef ALLOCATE_ARRAY

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3603,15 +3603,15 @@
36033603
CUDA: _triangular_solve_helper_cuda
36043604

36053605
- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
3606-
dispatch:
3607-
CPU: legacy::cpu::_th_symeig_out
3608-
CUDA: legacy::cuda::_th_symeig_out
36093606

36103607
- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
36113608
variants: method, function
3609+
3610+
- func: _symeig_helper(Tensor self, bool eigenvectors, bool upper) -> (Tensor, Tensor)
3611+
variants: function
36123612
dispatch:
3613-
CPU: legacy::cpu::_th_symeig
3614-
CUDA: legacy::cuda::_th_symeig
3613+
CPU: _symeig_helper_cpu
3614+
CUDA: _symeig_helper_cuda
36153615

36163616
- func: eig(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
36173617
dispatch:

aten/src/TH/generic/THLapack.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
77
TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
8-
TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
9-
TH_EXTERNC void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
108
TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
119
TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
1210
TH_EXTERNC void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
@@ -40,21 +38,6 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, s
4038
#endif
4139
}
4240

43-
/* Compute all eigenvalues and, optionally, eigenvectors of a real symmetric
44-
matrix A */
45-
void THLapack_(syev)(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info)
46-
{
47-
#ifdef USE_LAPACK
48-
#if defined(TH_REAL_IS_DOUBLE)
49-
dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
50-
#else
51-
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
52-
#endif
53-
#else
54-
THError("syev : Lapack library not found in compile time\n");
55-
#endif
56-
}
57-
5841
/* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and,
5942
optionally, the left and/or right eigenvectors */
6043
void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info)

aten/src/TH/generic/THLapack.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
/* ||AX-B|| */
66
TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info);
7-
/* Eigenvals */
8-
TH_API void THLapack_(syev)(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info);
97
/* Non-sym eigenvals */
108
TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info);
119
/* svd */

0 commit comments

Comments
 (0)