You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Port symeig to ATen and enable batching of inputs (#21858)
Summary:
Changelog:
- Port `symeig` from TH/THC to ATen
- Enable batching of matrix inputs for `symeig`
- Modify derivative computation based on batching
- Update docs to reflect the change
Pull Request resolved: #21858
Test Plan: - Added additional tests in `test_torch.py` (with a port to `test_cuda.py`) and `common_methods_invocations.py` to test if both the port and batching work.
Differential Revision: D15981789
Pulled By: soumith
fbshipit-source-id: ab9af8361f8608db42318aabc8421bd99a1ca7ae
Copy file name to clipboardExpand all lines: aten/src/TH/generic/THLapack.cpp
-17Lines changed: 0 additions & 17 deletions
Original file line number
Diff line number
Diff line change
@@ -5,8 +5,6 @@
5
5
6
6
TH_EXTERNC voiddgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
7
7
TH_EXTERNC voidsgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
8
-
TH_EXTERNC voiddsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
9
-
TH_EXTERNC voidssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
10
8
TH_EXTERNC voiddgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
11
9
TH_EXTERNC voidsgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
12
10
TH_EXTERNC voiddgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
@@ -40,21 +38,6 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, s
40
38
#endif
41
39
}
42
40
43
-
/* Compute all eigenvalues and, optionally, eigenvectors of a real symmetric
44
-
matrix A */
45
-
voidTHLapack_(syev)(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info)
46
-
{
47
-
#ifdef USE_LAPACK
48
-
#if defined(TH_REAL_IS_DOUBLE)
49
-
dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
50
-
#else
51
-
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
52
-
#endif
53
-
#else
54
-
THError("syev : Lapack library not found in compile time\n");
55
-
#endif
56
-
}
57
-
58
41
/* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and,
59
42
optionally, the left and/or right eigenvectors */
60
43
voidTHLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info)
0 commit comments