105 changes: 92 additions & 13 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -455,9 +455,29 @@ AT_ERROR("solve: MAGMA library not found in "
}

MAGMAQueue magma_queue(b.get_device());
magmaSolveBatched<scalar_t>(
n, nrhs, A_array, n, ipiv_array, b_array, n,
info_array, batch_size, magma_queue);

// Compute as many mini-batches of 65535 as possible.
// The number of full mini-batches is floor(batch_size / 65535),
// and together they cover floor(batch_size / 65535) * 65535 matrix solves.
int64_t mini_batches = batch_size / 65535, mini_idx;
for (mini_idx = 0; mini_idx < mini_batches * 65535; mini_idx += 65535) {
scalar_t** A_array_cur = &A_array[mini_idx];
scalar_t** b_array_cur = &b_array[mini_idx];
magma_int_t** ipiv_array_cur = &ipiv_array[mini_idx];
magma_int_t* info_array_cur = &info_array[mini_idx];

magmaSolveBatched<scalar_t>(
n, nrhs, A_array_cur, n, ipiv_array_cur, b_array_cur, n,
info_array_cur, 65535, magma_queue);
}

// Compute the remaining batch_size - floor(batch_size / 65535) * 65535 solves,
// which is simply batch_size % 65535.
if (batch_size % 65535 != 0) {
magmaSolveBatched<scalar_t>(
n, nrhs, &A_array[mini_idx], n, &ipiv_array[mini_idx], &b_array[mini_idx], n,
&info_array[mini_idx], batch_size % 65535, magma_queue);
}

for (int64_t i = 0; i < batch_size; i++) {
infos[i] = info_array[i];
@@ -521,9 +541,28 @@ AT_ERROR("inverse: MAGMA library not found in "
n, n, self_array, n, ipiv_array, info_array,
batch_size, magma_queue);

magmaGetriBatched<scalar_t>(
n, self_array, n, ipiv_array, self_inv_array,
n, info_array, batch_size, magma_queue);
// Compute as many mini-batches of 65535 as possible.
// The number of full mini-batches is floor(batch_size / 65535),
// and together they cover floor(batch_size / 65535) * 65535 matrix inversions.
int64_t mini_batches = batch_size / 65535, mini_idx;
for (mini_idx = 0; mini_idx < mini_batches * 65535; mini_idx += 65535) {
scalar_t** self_array_cur = &self_array[mini_idx];
scalar_t** self_inv_array_cur = &self_inv_array[mini_idx];
magma_int_t** ipiv_array_cur = &ipiv_array[mini_idx];
magma_int_t* info_array_cur = &info_array[mini_idx];

magmaGetriBatched<scalar_t>(
n, self_array_cur, n, ipiv_array_cur, self_inv_array_cur,
n, info_array_cur, 65535, magma_queue);
}

// Compute the remaining batch_size - floor(batch_size / 65535) * 65535 inversions,
// which is simply batch_size % 65535.
if (batch_size % 65535 != 0) {
magmaGetriBatched<scalar_t>(
n, &self_array[mini_idx], n, &ipiv_array[mini_idx], &self_inv_array[mini_idx],
n, &info_array[mini_idx], batch_size % 65535, magma_queue);
}

for (int64_t i = 0; i < batch_size; i++) {
infos[i] = info_array[i];
@@ -590,7 +629,7 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

int info_tmp;
int info_tmp = 0;
if (b.dim() == 2) {
magmaCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n,
b_data, n, &info_tmp);
@@ -613,9 +652,31 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
}

MAGMAQueue magma_queue(b.get_device());
magmaCholeskySolveBatched<scalar_t>(
uplo, n, nrhs, A_array, n, b_array, n,
info_tmp, batch_size, magma_queue);

// Compute as many mini-batches of 65535 as possible.
// The number of full mini-batches is floor(batch_size / 65535),
// and together they cover floor(batch_size / 65535) * 65535 matrix solves.
int64_t mini_batches = batch_size / 65535, mini_idx;
for (mini_idx = 0; mini_idx < mini_batches * 65535; mini_idx += 65535) {
scalar_t** A_array_cur = &A_array[mini_idx];
scalar_t** b_array_cur = &b_array[mini_idx];

magmaCholeskySolveBatched<scalar_t>(
uplo, n, nrhs, A_array_cur, n, b_array_cur, n,
info_tmp, 65535, magma_queue);

if (info_tmp != 0) {
break;
}
}

// Compute the remaining batch_size - floor(batch_size / 65535) * 65535 solves,
// which is simply batch_size % 65535.
if (batch_size % 65535 != 0 && info_tmp == 0) {
magmaCholeskySolveBatched<scalar_t>(
uplo, n, nrhs, &A_array[mini_idx], n, &b_array[mini_idx], n,
info_tmp, batch_size % 65535, magma_queue);
}

info = info_tmp;
}
@@ -928,9 +989,27 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
}

MAGMAQueue magma_queue(b.get_device());
magmaTriangularSolveBatched<scalar_t>(
uplo, trans, diag, n, nrhs, A_array, n,
b_array, n, batch_size, magma_queue);

// Compute as many mini-batches of 65535 as possible.
// The number of full mini-batches is floor(batch_size / 65535),
// and together they cover floor(batch_size / 65535) * 65535 matrix solves.
int64_t mini_batches = batch_size / 65535, mini_idx;
for (mini_idx = 0; mini_idx < mini_batches * 65535; mini_idx += 65535) {
scalar_t** A_array_cur = &A_array[mini_idx];
scalar_t** b_array_cur = &b_array[mini_idx];

magmaTriangularSolveBatched<scalar_t>(
uplo, trans, diag, n, nrhs, A_array_cur,
n, b_array_cur, n, 65535, magma_queue);
}

// Compute the remaining batch_size - floor(batch_size / 65535) * 65535 solves,
// which is simply batch_size % 65535.
if (batch_size % 65535 != 0) {
magmaTriangularSolveBatched<scalar_t>(
uplo, trans, diag, n, nrhs, &A_array[mini_idx],
n, &b_array[mini_idx], n, batch_size % 65535, magma_queue);
}
}
#endif
}
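Note: the same chunking logic is repeated at each of the four batched MAGMA call sites touched above (solve, getri, cholesky_solve, triangular_solve), and the new slow tests below drive it with 256 * 256 = 65536 and 512 * 512 = 262144 batched matrices, i.e. just past and well past the assumed 65535-per-call limit. Below is a minimal, self-contained sketch of the splitting arithmetic only; it is not code from this PR, and the names (kMaxMiniBatch, calls) are illustrative.

// Sketch: split a large batch into mini-batches that a single batched
// MAGMA call is assumed to accept (at most 65535 matrices per call).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const int64_t kMaxMiniBatch = 65535;  // assumed per-call limit
  const int64_t batch_size = 262144;    // e.g. the 512 * 512 batches used in the new tests

  // Each entry is the (offset, count) one batched MAGMA call would receive.
  std::vector<std::pair<int64_t, int64_t>> calls;
  for (int64_t offset = 0; offset < batch_size; offset += kMaxMiniBatch) {
    const int64_t count = std::min(kMaxMiniBatch, batch_size - offset);
    calls.emplace_back(offset, count);
  }

  for (const auto& call : calls) {
    std::cout << "offset=" << call.first << " count=" << call.second << "\n";
  }
  // For 262144 matrices this yields four calls of 65535 and a final call of 4;
  // in the real code the per-matrix pointer arrays are simply advanced by offset.
}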
20 changes: 20 additions & 0 deletions test/test_cuda.py
@@ -2178,6 +2178,11 @@ def _select_broadcastable_dims(dims_full=None):
def test_inverse(self):
_TestTorchMixin._test_inverse(self, lambda t: t.cuda())

@slowTest
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_inverse_many_batches(self):
_TestTorchMixin._test_inverse_slow(self, lambda t: t.cuda())

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_pinverse(self):
_TestTorchMixin._test_pinverse(self, lambda t: t.cuda())
@@ -2205,6 +2210,11 @@ def test_solve(self):
def test_solve_batched(self):
_TestTorchMixin._test_solve_batched(self, lambda t: t.cuda())

@slowTest
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_solve_batched_many_batches(self):
_TestTorchMixin._test_solve_batched_many_batches(self, lambda t: t.cuda())

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_solve_batched_dims(self):
_TestTorchMixin._test_solve_batched_dims(self, lambda t: t.cuda())
@@ -2217,6 +2227,11 @@ def test_cholesky_solve(self):
def test_cholesky_solve_batched(self):
_TestTorchMixin._test_cholesky_solve_batched(self, lambda t: t.cuda())

@slowTest
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_cholesky_solve_batched_many_batches(self):
_TestTorchMixin._test_cholesky_solve_batched_many_batches(self, lambda t: t.cuda())

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_cholesky_solve_batched_dims(self):
_TestTorchMixin._test_cholesky_solve_batched_dims(self, lambda t: t.cuda())
@@ -2700,6 +2715,11 @@ def test_triangular_solve(self):
def test_triangular_solve_batched(self):
_TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda())

@slowTest
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_triangular_solve_batched_many_batches(self):
_TestTorchMixin._test_triangular_solve_batched_many_batches(self, lambda t: t.cuda())

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_triangular_solve_batched_dims(self):
_TestTorchMixin._test_triangular_solve_batched_dims(self, lambda t: t.cuda())
94 changes: 93 additions & 1 deletion test/test_torch.py
@@ -5353,6 +5353,25 @@ def _test_solve_batched(self, cast):
def test_solve_batched(self):
self._test_solve_batched(self, lambda t: t)

@staticmethod
def _test_solve_batched_many_batches(self, cast):
from common_utils import random_fullrank_matrix_distinct_singular_value

A = cast(random_fullrank_matrix_distinct_singular_value(5, 256, 256))
b = cast(torch.randn(5, 1))
x, _ = torch.solve(b, A)
self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1)))

A = cast(random_fullrank_matrix_distinct_singular_value(3))
b = cast(torch.randn(512, 512, 3, 1))
x, _ = torch.solve(b, A)
self.assertEqual(torch.matmul(A, x), b)

@slowTest
@skipIfNoLapack
def test_solve_batched_many_batches(self):
self._test_solve_batched_many_batches(self, lambda t: t)

@staticmethod
def _test_solve_batched_dims(self, cast):
if not TEST_NUMPY:
@@ -5610,7 +5629,37 @@ def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular):

@skipIfNoLapack
def test_triangular_solve_batched(self):
_TestTorchMixin._test_triangular_solve_batched(self, lambda t: t)
self._test_triangular_solve_batched(self, lambda t: t)

@staticmethod
def _test_triangular_solve_batched_many_batches(self, cast):
def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular):
A = cast(torch.randn(*A_dims))
A = A.triu() if upper else A.tril()
if unitriangular:
A.diagonal(dim1=-2, dim2=-1).fill_(1.)
b = cast(torch.randn(*b_dims))
return A, b

for upper, transpose, unitriangular in product([True, False], repeat=3):
A, b = triangular_solve_test_helper((256, 256, 5, 5), (5, 1), cast, upper, unitriangular)
x, _ = torch.triangular_solve(b, A,
upper=upper, transpose=transpose, unitriangular=unitriangular)
if transpose:
A = A.transpose(-2, -1)
self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1)))

A, b = triangular_solve_test_helper((3, 3), (512, 512, 3, 1), cast, upper, unitriangular)
x, _ = torch.triangular_solve(b, A,
upper=upper, transpose=transpose, unitriangular=unitriangular)
if transpose:
A = A.transpose(-2, -1)
self.assertEqual(torch.matmul(A, x), b)

@slowTest
@skipIfNoLapack
def test_triangular_solve_batched_many_batches(self):
self._test_triangular_solve_batched_many_batches(self, lambda t: t)

@staticmethod
def _test_triangular_solve_batched_dims(self, cast):
@@ -6041,10 +6090,29 @@ def _test_inverse(self, conv_fn):
expected_inv = torch.as_tensor(inv(matrices.cpu().numpy()))
self.assertEqual(matrices_inverse, conv_fn(expected_inv))

@staticmethod
def _test_inverse_slow(self, conv_fn):
from common_utils import random_fullrank_matrix_distinct_singular_value

matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 256, 256))
matrices_inverse = torch.inverse(matrices)
self.assertEqual(torch.matmul(matrices_inverse, matrices),
conv_fn(torch.eye(5)).expand_as(matrices))

matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 512, 512))
matrices_inverse = torch.inverse(matrices)
self.assertEqual(torch.matmul(matrices, matrices_inverse),
conv_fn(torch.eye(3)).expand_as(matrices))

@skipIfNoLapack
def test_inverse(self):
self._test_inverse(self, lambda t: t)

@slowTest
@skipIfNoLapack
def test_inverse_many_batches(self):
self._test_inverse_slow(self, lambda t: t)

@staticmethod
def _test_pinverse(self, conv_fn):
def run_test(M):
@@ -6845,6 +6913,30 @@ def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
def test_cholesky_solve_batched(self):
self._test_cholesky_solve_batched(self, lambda t: t)

@staticmethod
def _test_cholesky_solve_batched_many_batches(self, cast):
from common_utils import random_symmetric_pd_matrix

def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
A = cast(random_symmetric_pd_matrix(*A_dims))
L = torch.cholesky(A, upper)
b = cast(torch.randn(*b_dims))
return A, L, b

for upper in [True, False]:
A, L, b = cholesky_solve_test_helper((5, 256, 256), (5, 10), cast, upper)
x = torch.cholesky_solve(b, L, upper)
self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 10)))

A, L, b = cholesky_solve_test_helper((5,), (512, 512, 5, 10), cast, upper)
x = torch.cholesky_solve(b, L, upper)
self.assertEqual(torch.matmul(A, x), b)

@skipIfNoLapack
@slowTest
def test_cholesky_solve_batched_many_batches(self):
self._test_cholesky_solve_batched_many_batches(self, lambda t: t)

@staticmethod
def _test_cholesky_solve_batched_dims(self, cast):
if not TEST_NUMPY: