5 changes: 2 additions & 3 deletions aten/src/ATen/native/Gesv.cpp
@@ -44,7 +44,7 @@ template<> void lapackGesv<double>(
#endif

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("gesv: LAPACK library not found in compilation");
#endif
@@ -117,8 +117,7 @@ std::tuple<Tensor&,Tensor&> gesv_out(
Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
if (self.dim() > 2 || A.dim() > 2) {
AT_ERROR("torch.gesv() with the `out` keyword does not support batching. "
"b.dim() (%lld) and A.dim() (%lld) must both be 2.",
(long long)self.dim(), (long long)A.dim());
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
}
return at::_gesv_single_out(solution, lu, self, A);
}
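The `gesv_out` hunk above swaps the printf-style `AT_ERROR` for ATen's variadic form while keeping the same restriction: the `out=` variant of `torch.gesv` only accepts 2-D inputs. A minimal sketch of how the check surfaces from Python, assuming a build that contains this change (shapes are illustrative):

```python
import torch

# Batched inputs are rejected when `out=` is used; only 2-D b and A are allowed.
A = torch.randn(3, 5, 5)
b = torch.randn(3, 5, 2)
solution, lu = torch.empty(0), torch.empty(0)
try:
    torch.gesv(b, A, out=(solution, lu))
except RuntimeError as e:
    # Expected to report that b.dim() and A.dim() must both be 2.
    print(e)
```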
15 changes: 6 additions & 9 deletions aten/src/ATen/native/Gesv.h
@@ -5,26 +5,23 @@ namespace at { namespace native {
static inline void checkInputs(const Tensor& self, const Tensor& A) {
if (A.size(-1) != A.size(-2)) {
AT_ERROR("A must be batches of square matrices, "
"but they are %lld by %lld matrices",
"but they are ", A.size(-1), " by ", A.size(-2), " matrices",
(long long)A.size(-1), (long long)A.size(-2));
}
if (A.size(-1) != self.size(-2)) {
AT_ERROR("Incompatible matrix sizes for matmul: each A "
"matrix is %llu by %lld but each b matrix is %lld by %lld.",
(long long)A.size(-1), (long long)A.size(-1),
(long long)self.size(-2), (long long)self.size(-1));
"matrix is ", A.size(-1), " by ", A.size(-1),
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
}

static inline void checkErrors(std::vector<int64_t> infos) {
static inline void checkErrors(std::vector<int64_t>& infos) {
for (size_t i = 0; i < infos.size(); i++) {
auto info = infos[i];
if (info < 0) {
AT_ERROR("gesv: For batch %lld: Argument %lld has illegal value",
(long long)i, -info);
AT_ERROR("gesv: For batch ", i, ": Argument ", -info, " has illegal value.");
} else if (info > 0) {
AT_ERROR("gesv: For batch %lld: U(%lld,%lld) is zero, singular U.",
(long long)i, info, info);
AT_ERROR("gesv: For batch ", i, ": U(", info, ",", info, ") is zero, singular U.");
}
}
}
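`checkInputs` validates shapes before dispatching to LAPACK/MAGMA, and `checkErrors` turns nonzero `info` codes into errors; both now build their messages with the variadic `AT_ERROR`. A minimal sketch of what trips the two shape checks from Python, assuming a LAPACK-enabled build with this change (shapes are illustrative):

```python
import torch

b = torch.randn(5, 3)

# A non-square A should raise "A must be batches of square matrices ...".
try:
    torch.gesv(b, torch.randn(5, 4))
except RuntimeError as e:
    print(e)

# A square A whose order does not match b's row count should raise
# "Incompatible matrix sizes for matmul ...".
try:
    torch.gesv(b, torch.randn(4, 4))
except RuntimeError as e:
    print(e)
```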
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Gesv.cu
@@ -84,7 +84,7 @@ static inline Storage pin_memory(int64_t size, Tensor dummy) {
name = static_cast<type*>(storage_##name.data());

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_MAGMA
AT_ERROR("gesv: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
65 changes: 65 additions & 0 deletions test/common.py
@@ -586,3 +586,68 @@ def find_free_port():
sockname = sock.getsockname()
sock.close()
return sockname[1]


# Methods for matrix generation
# Used in test_autograd.py and test_torch.py
def prod_single_zero(dim_size):
result = torch.randn(dim_size, dim_size)
result[0, 1] = 0
return result


def random_square_matrix_of_rank(l, rank):
assert rank <= l
A = torch.randn(l, l)
u, s, v = A.svd()
for i in range(l):
if i >= rank:
s[i] = 0
elif s[i] == 0:
s[i] = 1
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))


def random_symmetric_matrix(l):
A = torch.randn(l, l)
for i in range(l):
for j in range(i):
A[i, j] = A[j, i]
return A


def random_symmetric_psd_matrix(l):
A = torch.randn(l, l)
return A.mm(A.transpose(0, 1))


def random_symmetric_pd_matrix(l, eps=1e-5):
A = torch.randn(l, l)
return A.mm(A.transpose(0, 1)) + torch.eye(l) * eps


def make_nonzero_det(A, sign=None, min_singular_value=0.1):
u, s, v = A.svd()
s[s < min_singular_value] = min_singular_value
A = u.mm(torch.diag(s)).mm(v.t())
det = A.det().item()
if sign is not None:
if (det < 0) ^ (sign < 0):
A[0, :].neg_()
return A


def random_fullrank_matrix_distinct_singular_value(l, *batches):
if len(batches) == 0:
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
return u.mm(torch.diag(s)).mm(v.t())
else:
all_matrices = []
for _ in range(0, torch.prod(torch.as_tensor(batches)).item()):
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
all_matrices.append(u.mm(torch.diag(s)).mm(v.t()))
return torch.stack(all_matrices).reshape(*(batches + (l, l)))
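The helper above now accepts optional batch dimensions: `random_fullrank_matrix_distinct_singular_value(l)` returns a single `l x l` matrix, while `random_fullrank_matrix_distinct_singular_value(l, *batches)` returns a tensor of shape `(*batches, l, l)`. A short usage sketch, assuming it is run from the test directory so `common` is importable:

```python
import torch
from common import random_fullrank_matrix_distinct_singular_value

single = random_fullrank_matrix_distinct_singular_value(5)       # shape (5, 5)
batch = random_fullrank_matrix_distinct_singular_value(5, 2, 3)  # shape (2, 3, 5, 5)
assert single.shape == (5, 5)
assert batch.shape == (2, 3, 5, 5)

# Each matrix is full rank by construction: its singular values are
# k / (l + 1) for k = 1..l, so they are distinct and strictly positive.
for M in batch.reshape(-1, 5, 5):
    s = M.svd()[1]
    assert (s > 0).all()
```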
76 changes: 15 additions & 61 deletions test/test_autograd.py
@@ -14,8 +14,12 @@
from torch.autograd.gradcheck import gradgradcheck, gradcheck
from torch.autograd.function import once_differentiable
from torch.autograd.profiler import profile
from common import TEST_MKL, TestCase, run_tests, skipIfNoLapack, \
suppress_warnings, skipIfRocm
from common import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
suppress_warnings, skipIfRocm,
prod_single_zero, random_square_matrix_of_rank,
random_symmetric_matrix, random_symmetric_psd_matrix,
random_symmetric_pd_matrix, make_nonzero_det,
random_fullrank_matrix_distinct_singular_value)
from torch.autograd import Variable, Function, detect_anomaly
from torch.autograd.function import InplaceFunction
from torch.testing import make_non_contiguous, randn_like
@@ -2566,60 +2570,6 @@ def prod_zeros(dim_size, dim_select):
return result


def prod_single_zero(dim_size):
result = torch.randn(dim_size, dim_size)
result[0, 1] = 0
return result


def random_square_matrix_of_rank(l, rank):
assert rank <= l
A = torch.randn(l, l)
u, s, v = A.svd()
for i in range(l):
if i >= rank:
s[i] = 0
elif s[i] == 0:
s[i] = 1
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))


def random_symmetric_matrix(l):
A = torch.randn(l, l)
for i in range(l):
for j in range(i):
A[i, j] = A[j, i]
return A


def random_symmetric_psd_matrix(l):
A = torch.randn(l, l)
return A.mm(A.transpose(0, 1))


def random_symmetric_pd_matrix(l, eps=1e-5):
A = torch.randn(l, l)
return A.mm(A.transpose(0, 1)) + torch.eye(l) * eps


def make_nonzero_det(A, sign=None, min_singular_value=0.1):
u, s, v = A.svd()
s[s < min_singular_value] = min_singular_value
A = u.mm(torch.diag(s)).mm(v.t())
det = A.det().item()
if sign is not None:
if (det < 0) ^ (sign < 0):
A[0, :].neg_()
return A


def random_fullrank_matrix_distinct_singular_value(l):
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
return u.mm(torch.diag(s)).mm(v.t())


def uniform_scalar(offset=0, requires_grad=False):
v = torch.rand(()) + offset
v.requires_grad = requires_grad
@@ -3151,11 +3101,15 @@ class dont_convert(tuple):
'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
'large', NO_ARGS, [skipIfNoLapack]),
('gesv', (S, S), ((S, S),), '', NO_ARGS, [skipIfNoLapack]),
('gesv', (S, S, S), ((S, S, S),), 'batched', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (2, 3, S, S), ((2, 3, S, S),), 'batched_dims', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (2, 2, S, S), ((1, S, S),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (1, S, S), ((2, 2, S, S),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (S, S), (random_fullrank_matrix_distinct_singular_value(S),), '', NO_ARGS, [skipIfNoLapack]),
('gesv', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S),),
'batched', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3),),
'batched_dims', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1),),
'batched_broadcast_A', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('gesv', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2),),
'batched_broadcast_b', NO_ARGS, [skipIfNoLapack, skipIfRocm]),
('fill_', (S, S, S), (1,), 'number'),
('fill_', (), (1,), 'number_scalar'),
# FIXME: we should compute the derivative w.r.t torch.tensor(1)
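The gesv entries now feed gradcheck a full-rank, well-conditioned `A` from the shared helper instead of a raw `torch.randn` matrix, which could be (near-)singular and make the numerical Jacobian unstable. A rough sketch of the kind of check these entries drive, not the harness itself (assumes `common` is importable and a LAPACK-enabled build):

```python
import torch
from torch.autograd import gradcheck
from common import random_fullrank_matrix_distinct_singular_value

# Double precision keeps gradcheck's finite differences stable.
A = random_fullrank_matrix_distinct_singular_value(5).double().requires_grad_()
b = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)

# Check gradients of the solution of A x = b with respect to both b and A.
assert gradcheck(lambda b, A: torch.gesv(b, A)[0], (b, A))
```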
23 changes: 12 additions & 11 deletions test/test_torch.py
@@ -3965,16 +3965,17 @@ def test_gesv(self):

@staticmethod
def _test_gesv_batched(self, cast):
from common import random_fullrank_matrix_distinct_singular_value as fullrank
# test against gesv: one batch
A = cast(torch.randn(1, 5, 5))
A = cast(fullrank(5, 1))
b = cast(torch.randn(1, 5, 10))
x_exp, LU_exp = torch.gesv(b.squeeze(0), A.squeeze(0))
x, LU = torch.gesv(b, A)
self.assertEqual(x, x_exp.unsqueeze(0))
self.assertEqual(LU, LU_exp.unsqueeze(0))

# test against gesv in a loop: four batches
A = cast(torch.randn(4, 5, 5))
A = cast(fullrank(5, 4))
b = cast(torch.randn(4, 5, 10))

x_exp_list = list()
@@ -3991,7 +3992,7 @@ def _test_gesv_batched(self, cast):
self.assertEqual(LU, LU_exp)

# basic correctness test
A = cast(torch.randn(3, 5, 5))
A = cast(fullrank(5, 3))
b = cast(torch.randn(3, 5, 10))
x, LU = torch.gesv(b, A)
self.assertEqual(torch.matmul(A, x), b)
@@ -4001,7 +4002,7 @@ def _test_gesv_batched(self, cast):
return
import numpy
from numpy.linalg import solve
A = cast(torch.randn(2, 2, 2)).permute(1, 0, 2)
A = cast(fullrank(2, 2)).permute(1, 0, 2)
b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
x, _ = torch.gesv(b, A)
x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
@@ -4016,18 +4017,18 @@ def _test_gesv_batched_dims(self, cast):
if not TEST_NUMPY:
return

import numpy
from numpy.linalg import solve
from common import random_fullrank_matrix_distinct_singular_value as fullrank

# test against numpy.linalg.solve
A = cast(torch.randn(2, 1, 3, 4, 4))
A = cast(fullrank(4, 2, 1, 3))
b = cast(torch.randn(2, 1, 3, 4, 6))
x, _ = torch.gesv(b, A)
x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
self.assertEqual(x.data, cast(x_exp))

# test column major format
A = cast(torch.randn(2, 1, 3, 4, 4)).transpose(-2, -1)
A = cast(fullrank(4, 2, 1, 3)).transpose(-2, -1)
b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1)
assert not A.is_contiguous()
assert not b.is_contiguous()
@@ -4036,21 +4037,21 @@ def _test_gesv_batched_dims(self, cast):
self.assertEqual(x.data, cast(x_exp))

# broadcasting b
A = cast(torch.randn(2, 1, 3, 4, 4))
A = cast(fullrank(4, 2, 1, 3))
b = cast(torch.randn(4, 6))
x, _ = torch.gesv(b, A)
x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
self.assertEqual(x.data, cast(x_exp))

# broadcasting A
A = cast(torch.randn(4, 4))
A = cast(fullrank(4))
b = cast(torch.randn(2, 1, 3, 4, 2))
x, _ = torch.gesv(b, A)
x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
self.assertEqual(x.data, cast(x_exp))

# broadcasting both A & b
A = cast(torch.randn(1, 3, 1, 4, 4))
A = cast(fullrank(4, 1, 3, 1))
b = cast(torch.randn(2, 1, 3, 4, 5))
x, _ = torch.gesv(b, A)
x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
@@ -4631,7 +4632,7 @@ def run_test(M, sign=1):

# Single matrix, but full rank
# This is for negative powers
from test_autograd import random_fullrank_matrix_distinct_singular_value
from common import random_fullrank_matrix_distinct_singular_value
M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
run_test(M)
run_test(M, sign=-1)
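The matrix_power test imports the full-rank generator from `common` now that it lives there; full rank matters because negative powers are computed through the inverse. A small sketch of that property using the same helper (assumes `common` is importable):

```python
import torch
from common import random_fullrank_matrix_distinct_singular_value

M = random_fullrank_matrix_distinct_singular_value(5)
inv = torch.matrix_power(M, -1)  # only defined because M is invertible

# M @ M^-1 should be the identity up to floating-point error.
err = (M.mm(inv) - torch.eye(5)).abs().max()
assert err < 1e-4
```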