54 changes: 25 additions & 29 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -19,11 +19,7 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
p.squeeze_(0);
lu.squeeze_(0);
int int_info = info.squeeze_().toCInt();
if (int_info < 0) {
std::ostringstream ss;
ss << "LU factorization (getrf) failed with info = " << int_info;
throw std::runtime_error(ss.str());
}
AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
auto n = self.size(0);
auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0);
if (num_exchanges % 2 == 1) {
@@ -34,13 +30,10 @@ }
}

Tensor det(const Tensor& self) {
if (!at::isFloatingType(self.type().scalarType()) ||
self.dim() != 2 || self.size(0) != self.size(1)) {
std::ostringstream ss;
ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D "
<< "square tensor of floating types";
throw std::runtime_error(ss.str());
}
AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
self.dim() == 2 && self.size(0) == self.size(1),
"det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
"of floating types");
double det_P;
Tensor diag_U;
int info;
@@ -53,13 +46,10 @@ Tensor det(const Tensor& self) {
}

Tensor logdet(const Tensor& self) {
if (!at::isFloatingType(self.type().scalarType()) ||
self.dim() != 2 || self.size(0) != self.size(1)) {
std::ostringstream ss;
ss << "logdet(" << self.type() << "{" << self.sizes() << "}): expected a "
<< "2D square tensor of floating types";
throw std::runtime_error(ss.str());
}
AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
self.dim() == 2 && self.size(0) == self.size(1),
"logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
"of floating types");
double det_P;
Tensor diag_U, det;
int info;
@@ -77,13 +67,10 @@ Tensor logdet(const Tensor& self) {
}

std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
if (!at::isFloatingType(self.type().scalarType()) ||
self.dim() != 2 || self.size(0) != self.size(1)) {
std::ostringstream ss;
ss << "slogdet(" << self.type() << "{" << self.sizes() << "}): expected a "
<< "2D square tensor of floating types";
throw std::runtime_error(ss.str());
}
AT_CHECK(at::isFloatingType(self.type().scalarType()) &&
self.dim() == 2 && self.size(0) == self.size(1),
"slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
"of floating types");
double det_P;
Tensor diag_U, det;
int info;
@@ -96,10 +83,19 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
}

Tensor pinverse(const Tensor& self, double rcond) {
AT_CHECK(at::isFloatingType(self.type().scalarType()) && self.dim() == 2,
"pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor "
"of floating types");
Tensor U, S, V;
std::tie(U, S, V) = self.svd();
double max_val = S[0].toCDouble();
Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options()));
return V.mm(S_pseudoinv.diag().mm(U.t()));
}

static void check_1d(const Tensor& t, const char* arg, const char* fn) {
if (t.dim() != 1) {
AT_ERROR(fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}
AT_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}

Tensor ger(const Tensor& self, const Tensor& vec2) {
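For readers following the new pinverse kernel above, here is a minimal Python sketch of the same computation (not part of this PR; pinverse_sketch is an illustrative name). It forms A+ = V diag(S+) U^T, where S+ inverts only the singular values above the rcond cutoff:

import torch

def pinverse_sketch(A, rcond=1e-15):
    # Mirrors the C++ above: singular values at or below rcond * sigma_max
    # are treated as zero instead of being inverted.
    U, S, V = A.svd()                     # A = U diag(S) V^T, S sorted descending
    cutoff = rcond * S[0]                 # S[0] is the largest singular value
    S_pinv = torch.where(S > cutoff, S.reciprocal(), torch.zeros_like(S))
    return V.mm(S_pinv.diag()).mm(U.t())  # n x m result for an m x n input

For an invertible square matrix, this agrees with A.inverse() up to numerical error.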
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -903,6 +903,8 @@

- func: pin_memory(Tensor self) -> Tensor

- func: pinverse(Tensor self, double rcond=1e-15) -> Tensor

- func: rand(IntList size, *, TensorOptions options={}) -> Tensor
variants: function

1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: ormqr
.. automethod:: permute
.. automethod:: pin_memory
.. automethod:: pinverse
.. automethod:: potrf
.. automethod:: potri
.. automethod:: potrs
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -288,6 +288,7 @@ BLAS and LAPACK Operations
.. autofunction:: mv
.. autofunction:: orgqr
.. autofunction:: ormqr
.. autofunction:: pinverse
.. autofunction:: potrf
.. autofunction:: potri
.. autofunction:: potrs
23 changes: 23 additions & 0 deletions test/test_autograd.py
@@ -2040,6 +2040,29 @@ def run_test(input_size, exponent):
run_test((10, 10), torch.zeros(10, 10))
run_test((10,), 0)

def test_pinverse(self):
# Why is pinverse tested this way, rather than in the usual way for linear algebra methods?
# 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable
# 2. Derivatives of pseudo-inverses typically exist only at constant rank (Golub et al., 1973)
# 3. This method creates two orthogonal matrices and constructs a test case with large
# singular values (given by x to the function); a standalone sketch follows this file's diff.
# 4. This ensures that small perturbations don't affect the rank of the matrix, in which case
# a derivative exists.
# 5. This test exists since pinverse is implemented using SVD, and is hence backprop-able
m, n = 5, 10
U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n
V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n

def func(x):
S = torch.cat([x, torch.zeros(n - m)], 0)
M = U.mm(torch.diag(S)).mm(V.t())
return M.pinverse()

gradcheck(func, [torch.rand(m) + 1])
gradcheck(func, [torch.rand(m) + 10])

gradgradcheck(func, [torch.rand(m) + 1])
gradgradcheck(func, [torch.rand(m) + 10])

def test_profiler(self):
x = torch.randn(10, 10)

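A standalone sketch of the constant-rank construction used in test_pinverse above (assuming double precision, which gradcheck prefers; Q1, Q2, func, and x are illustrative names, not part of the PR). Because every singular value is at least 1, the small perturbations gradcheck applies cannot change the rank, which is exactly the regime in which the pseudo-inverse is differentiable:

import torch
from torch.autograd import gradcheck

m = 5
# Fixed orthogonal factors; only the singular values x are treated as inputs.
Q1 = torch.randn(m, m, dtype=torch.double).qr()[0]
Q2 = torch.randn(m, m, dtype=torch.double).qr()[0]

def func(x):
    # M = Q1 diag(x) Q2^T has singular values |x|; with x >= 1 elementwise,
    # small perturbations cannot drop the rank of M.
    return Q1.mm(torch.diag(x)).mm(Q2.t()).pinverse()

x = (torch.rand(m, dtype=torch.double) + 1).requires_grad_()
print(gradcheck(func, [x]))  # prints True when analytic and numeric gradients agree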
4 changes: 4 additions & 0 deletions test/test_cuda.py
@@ -1385,6 +1385,10 @@ def test_caching_pinned_memory_multi_gpu(self):
def _select_broadcastable_dims(dims_full=None):
return TestTorch._select_broadcastable_dims(dims_full)

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_pinverse(self):
TestTorch._test_pinverse(self, lambda t: t.cuda())

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_det_logdet_slogdet(self):
TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())
27 changes: 27 additions & 0 deletions test/test_torch.py
@@ -4041,6 +4041,33 @@ def test_inverse(self):
self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
self.assertEqual(MII, MI, 0, 'inverse value in-place')

@staticmethod
def _test_pinverse(self, conv_fn):
def run_test(M):
# Testing against definition for pseudo-inverses
MPI = torch.pinverse(M)
self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1')
self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2')
self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3')
self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4')

# Square matrix
M = conv_fn(torch.randn(5, 5))
run_test(M)

# Rectangular matrix
M = conv_fn(torch.randn(3, 4))
run_test(M)

# Test inverse and pseudo-inverse for invertible matrix
M = torch.randn(5, 5)
M = conv_fn(M.mm(M.t()))
self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix')

@skipIfNoLapack
def test_pinverse(self):
self._test_pinverse(self, conv_fn=lambda x: x)

@staticmethod
def _test_det_logdet_slogdet(self, conv_fn):
def reference_det(M):
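As an extra local sanity check on the four Moore-Penrose conditions asserted in _test_pinverse above (a sketch, not part of this PR): numpy.linalg.pinv uses the same SVD-based definition and the same default cutoff convention, so the two should agree to numerical precision:

import numpy as np
import torch

M = torch.randn(3, 4, dtype=torch.double)
ours = torch.pinverse(M).numpy()
ref = np.linalg.pinv(M.numpy())            # same Moore-Penrose definition
print(np.allclose(ours, ref, atol=1e-10))  # expect True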
7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -2522,3 +2522,10 @@ def callable(a, b) -> number

See :func:`torch.slogdet`
""")

add_docstr_all('pinverse',
r"""
pinverse() -> Tensor

See :func:`torch.pinverse`
""")
46 changes: 46 additions & 0 deletions torch/_torch_docs.py
@@ -5264,6 +5264,52 @@ def parse_kwargs(desc):
(tensor(-1.), tensor(1.5731))
""")

add_docstr(torch.pinverse,
r"""
pinverse(input, rcond=1e-15) -> Tensor

Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor.
See `Moore-Penrose inverse`_ for more details.

.. note::
This method is implemented using the Singular Value Decomposition.

.. note::
The pseudo-inverse is not necessarily a continuous function of the elements of the matrix `[1]`_.
Therefore, derivatives do not always exist, and exist only for matrices of constant rank `[2]`_.
However, this method is backprop-able because it is implemented using SVD results, and it
could be unstable. Double-backward will also be unstable due to the usage of SVD internally.
See :meth:`~torch.svd` for more details.

Arguments:
input (Tensor): The input 2D tensor of dimensions :math:`m \times n`
rcond (float): A floating point value to determine the cutoff for small singular values.
Default: 1e-15

Returns:
The pseudo-inverse of :attr:`input` of dimensions :math:`n \times m`

Example::

>>> input = torch.randn(3, 5)
>>> input
tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132],
[-1.1143, -0.3662, 0.3042, 1.6374, -0.9294],
[-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.pinverse(input)
tensor([[ 0.0600, -0.1933, -0.2090],
[-0.0903, -0.0817, -0.4752],
[-0.7124, -0.1631, -0.2272],
[ 0.1356, 0.3933, -0.5023],
[-0.0308, -0.1725, -0.5216]])

.. _Moore-Penrose inverse: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse

.. _[1]: https://epubs.siam.org/doi/10.1137/0117004

.. _[2]: https://www.jstor.org/stable/2156365
""")

add_docstr(torch.fft,
r"""
fft(input, signal_ndim, normalized=False) -> Tensor