Commit 14cbd9a

vishwakftw authored and facebook-github-bot committed
Implement torch.pinverse : Pseudo-inverse (#9052)
Summary:
1. Used SVD to compute the pseudo-inverse.
2. Tests in test_autograd, test_cuda and test_torch
3. Doc strings in _torch_docs.py and _tensor_docs.py

Closes #6187
Closes #9052

Reviewed By: soumith

Differential Revision: D8714628

Pulled By: SsnL

fbshipit-source-id: 7e006c9d138b9f49e703bd0ffdabe6253be78dd9
1 parent f6027bb commit 14cbd9a

File tree

9 files changed: +136, -29 lines


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 25 additions & 29 deletions
@@ -19,11 +19,7 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
   p.squeeze_(0);
   lu.squeeze_(0);
   int int_info = info.squeeze_().toCInt();
-  if (int_info < 0) {
-    std::ostringstream ss;
-    ss << "LU factorization (getrf) failed with info = " << int_info;
-    throw std::runtime_error(ss.str());
-  }
+  AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
   auto n = self.size(0);
   auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0);
   if (num_exchanges % 2 == 1) {
@@ -34,13 +30,10 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
 }
 
 Tensor det(const Tensor& self) {
-  if (!at::isFloatingType(self.type().scalarType()) ||
-      self.dim() != 2 || self.size(0) != self.size(1)) {
-    std::ostringstream ss;
-    ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D "
-       << "square tensor of floating types";
-    throw std::runtime_error(ss.str());
-  }
+  AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
+           self.dim() == 2 && self.size(0) == self.size(1),
+           "det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
+           "of floating types");
   double det_P;
   Tensor diag_U;
   int info;
@@ -53,13 +46,10 @@ Tensor det(const Tensor& self) {
 }
 
 Tensor logdet(const Tensor& self) {
-  if (!at::isFloatingType(self.type().scalarType()) ||
-      self.dim() != 2 || self.size(0) != self.size(1)) {
-    std::ostringstream ss;
-    ss << "logdet(" << self.type() << "{" << self.sizes() << "}): expected a "
-       << "2D square tensor of floating types";
-    throw std::runtime_error(ss.str());
-  }
+  AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
+           self.dim() == 2 && self.size(0) == self.size(1),
+           "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
+           "of floating types");
   double det_P;
   Tensor diag_U, det;
   int info;
@@ -77,13 +67,10 @@ Tensor logdet(const Tensor& self) {
 }
 
 std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
-  if (!at::isFloatingType(self.type().scalarType()) ||
-      self.dim() != 2 || self.size(0) != self.size(1)) {
-    std::ostringstream ss;
-    ss << "slogdet(" << self.type() << "{" << self.sizes() << "}): expected a "
-       << "2D square tensor of floating types";
-    throw std::runtime_error(ss.str());
-  }
+  AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
+           self.dim() == 2 && self.size(0) == self.size(1),
+           "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
+           "of floating types");
   double det_P;
   Tensor diag_U, det;
   int info;
@@ -96,10 +83,19 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
   return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
 }
 
+Tensor pinverse(const Tensor& self, double rcond) {
+  AT_CHECK(at::isFloatingType(self.type().scalarType()) && self.dim() == 2,
+           "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor "
+           "of floating types");
+  Tensor U, S, V;
+  std::tie(U, S, V) = self.svd();
+  double max_val = S[0].toCDouble();
+  Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options()));
+  return V.mm(S_pseudoinv.diag().mm(U.t()));
+}
+
 static void check_1d(const Tensor& t, const char* arg, const char* fn) {
-  if (t.dim() != 1) {
-    AT_ERROR(fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
-  }
+  AT_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
 }
 
 Tensor ger(const Tensor& self, const Tensor& vec2) {
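The new pinverse kernel above computes A+ = V diag(S+) U^T from the SVD, zeroing the reciprocals of singular values at or below rcond times the largest one. A minimal Python sketch of the same computation using public ops (the helper name pinverse_sketch is illustrative, not part of the commit):

    import torch

    def pinverse_sketch(A, rcond=1e-15):
        # Thin SVD: A = U diag(S) V^T, with S sorted in descending order,
        # so S[0] is the largest singular value.
        U, S, V = torch.svd(A)
        # Keep 1/sigma for singular values above the relative cutoff, zero otherwise.
        S_pinv = torch.where(S > rcond * S[0], S.reciprocal(), torch.zeros_like(S))
        # A+ = V diag(S+) U^T; an n x m result for an m x n input.
        return V.mm(torch.diag(S_pinv)).mm(U.t())

    A = torch.randn(3, 5)
    print(torch.allclose(pinverse_sketch(A), torch.pinverse(A), atol=1e-6))  # True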

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -903,6 +903,8 @@
 
 - func: pin_memory(Tensor self) -> Tensor
 
+- func: pinverse(Tensor self, double rcond=1e-15) -> Tensor
+
 - func: rand(IntList size, *, TensorOptions options={}) -> Tensor
   variants: function
 
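Since the declaration leaves variants at the default (method and function), this single yaml entry generates both torch.pinverse(t) and t.pinverse(). A quick sanity check (illustrative, not from the commit):

    import torch

    x = torch.randn(3, 5)
    # The function form and the method form dispatch to the same native kernel.
    assert torch.allclose(torch.pinverse(x), x.pinverse())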

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions
@@ -307,6 +307,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: ormqr
    .. automethod:: permute
    .. automethod:: pin_memory
+   .. automethod:: pinverse
    .. automethod:: potrf
    .. automethod:: potri
    .. automethod:: potrs

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ BLAS and LAPACK Operations
 .. autofunction:: mv
 .. autofunction:: orgqr
 .. autofunction:: ormqr
+.. autofunction:: pinverse
 .. autofunction:: potrf
 .. autofunction:: potri
 .. autofunction:: potrs

test/test_autograd.py

Lines changed: 23 additions & 0 deletions
@@ -2040,6 +2040,29 @@ def run_test(input_size, exponent):
         run_test((10, 10), torch.zeros(10, 10))
         run_test((10,), 0)
 
+    def test_pinverse(self):
+        # Why is pinverse tested this way, and not ordinarily as other linear algebra methods?
+        # 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable
+        # 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973)
+        # 3. This method creates two orthogonal matrices, and constructs a test case with large
+        #    singular values (given by x to the function).
+        # 4. This will ensure that small perturbations don't affect the rank of the matrix, in which
+        #    case a derivative exists.
+        # 5. This test exists since pinverse is implemented using SVD, and is hence a backprop-able method
+        m, n = 5, 10
+        U = torch.randn(n, m).qr()[0].t()  # Orthogonal with dimensions m x n
+        V = torch.randn(n, m).qr()[0].t()  # Orthogonal with dimensions m x n
+
+        def func(x):
+            S = torch.cat([x, torch.zeros(n - m)], 0)
+            M = U.mm(torch.diag(S)).mm(V.t())
+            return M.pinverse()
+
+        gradcheck(func, [torch.rand(m) + 1])
+        gradcheck(func, [torch.rand(m) + 10])
+        gradgradcheck(func, [torch.rand(m) + 1])
+        gradgradcheck(func, [torch.rand(m) + 10])
+
     def test_profiler(self):
         x = torch.randn(10, 10)
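A side note on the constant-rank condition cited in the comment: when A has locally constant rank, the pseudo-inverse is differentiable, and its differential (the Golub-Pereyra form, quoted here from the standard result for real matrices; verify conventions before relying on it) can be written in LaTeX as

    dA^{+} = -A^{+}\,(dA)\,A^{+}
             + A^{+} A^{+\top} (dA)^{\top} (I - A A^{+})
             + (I - A^{+} A)\,(dA)^{\top} A^{+\top} A^{+}

The test constructs exactly this regime: singular values bounded away from zero (torch.rand(m) + 1 or + 10), so the small perturbations applied by gradcheck cannot change the rank.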

test/test_cuda.py

Lines changed: 4 additions & 0 deletions
@@ -1385,6 +1385,10 @@ def test_caching_pinned_memory_multi_gpu(self):
     def _select_broadcastable_dims(dims_full=None):
         return TestTorch._select_broadcastable_dims(dims_full)
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_pinverse(self):
+        TestTorch._test_pinverse(self, lambda t: t.cuda())
+
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_det_logdet_slogdet(self):
         TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())

test/test_torch.py

Lines changed: 27 additions & 0 deletions
@@ -4041,6 +4041,33 @@ def test_inverse(self):
         self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
         self.assertEqual(MII, MI, 0, 'inverse value in-place')
 
+    @staticmethod
+    def _test_pinverse(self, conv_fn):
+        def run_test(M):
+            # Testing against definition for pseudo-inverses
+            MPI = torch.pinverse(M)
+            self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1')
+            self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2')
+            self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3')
+            self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4')
+
+        # Square matrix
+        M = conv_fn(torch.randn(5, 5))
+        run_test(M)
+
+        # Rectangular matrix
+        M = conv_fn(torch.randn(3, 4))
+        run_test(M)
+
+        # Test inverse and pseudo-inverse for invertible matrix
+        M = torch.randn(5, 5)
+        M = conv_fn(M.mm(M.t()))
+        self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix')
+
+    @skipIfNoLapack
+    def test_pinverse(self):
+        self._test_pinverse(self, conv_fn=lambda x: x)
+
     @staticmethod
     def _test_det_logdet_slogdet(self, conv_fn):
         def reference_det(M):
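The four "pseudo-inverse conditions" asserted in run_test are the defining Moore-Penrose equations; the matrix M+ satisfying all four is unique. For real M, in LaTeX:

    M M^{+} M = M, \qquad
    M^{+} M M^{+} = M^{+}, \qquad
    (M M^{+})^{\top} = M M^{+}, \qquad
    (M^{+} M)^{\top} = M^{+} M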

torch/_tensor_docs.py

Lines changed: 7 additions & 0 deletions
@@ -2522,3 +2522,10 @@ def callable(a, b) -> number
 
 See :func:`torch.slogdet`
 """)
+
+add_docstr_all('pinverse',
+               r"""
+pinverse() -> Tensor
+
+See :func:`torch.pinverse`
+""")

torch/_torch_docs.py

Lines changed: 46 additions & 0 deletions
@@ -5264,6 +5264,52 @@ def parse_kwargs(desc):
 (tensor(-1.), tensor(1.5731))
 """)
 
+add_docstr(torch.pinverse,
+           r"""
+pinverse(input, rcond=1e-15) -> Tensor
+
+Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor.
+Please look at `Moore-Penrose inverse`_ for more details.
+
+.. note::
+    This method is implemented using the Singular Value Decomposition.
+
+.. note::
+    The pseudo-inverse is not necessarily a continuous function in the elements of the matrix `[1]`_.
+    Therefore, derivatives do not always exist, and exist only for constant rank `[2]`_.
+    However, this method is backprop-able since it is implemented using SVD results, and
+    could be unstable. Double-backward will also be unstable due to the usage of SVD internally.
+    See :meth:`~torch.svd` for more details.
+
+Arguments:
+    input (Tensor): The input 2D tensor of dimensions :math:`m \times n`
+    rcond (float): A floating point value to determine the cutoff for small singular values.
+                   Default: 1e-15
+
+Returns:
+    The pseudo-inverse of :attr:`input` of dimensions :math:`n \times m`
+
+Example::
+
+    >>> input = torch.randn(3, 5)
+    >>> input
+    tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
+            [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
+            [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
+    >>> torch.pinverse(input)
+    tensor([[ 0.0600, -0.1933, -0.2090],
+            [-0.0903, -0.0817, -0.4752],
+            [-0.7124, -0.1631, -0.2272],
+            [ 0.1356,  0.3933, -0.5023],
+            [-0.0308, -0.1725, -0.5216]])
+
+.. _Moore-Penrose inverse: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse
+
+.. _[1]: https://epubs.siam.org/doi/10.1137/0117004
+
+.. _[2]: https://www.jstor.org/stable/2156365
+""")
+
 add_docstr(torch.fft,
            r"""
 fft(input, signal_ndim, normalized=False) -> Tensor
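Beyond the docstring example, a short sketch of where rcond matters (illustrative usage, not part of the commit): a rank-deficient matrix has no ordinary inverse, yet pinverse is still defined, and rcond sets the relative cutoff below which singular values are treated as zero:

    import torch

    u = torch.randn(4, 1)
    A = u.mm(u.t())             # rank-1, 4 x 4; torch.inverse(A) would be singular

    A_pinv = torch.pinverse(A)  # default rcond=1e-15
    # Penrose condition 1 still holds for the rank-deficient input.
    print(torch.allclose(A.mm(A_pinv).mm(A), A, atol=1e-4))  # True

    # A larger rcond treats more (relatively small) singular values as zero.
    A_pinv_coarse = torch.pinverse(A, rcond=1e-3)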
