
Commit 7162649

zou3519 authored and soumith committed
Add batched linear solver to torch.gesv() (#6100)
* Add batched linear solver to torch.gesv()

  Fixes #3164. Picks up from #4502.

  I moved `gesv` to ATen. This adds bindings for MAGMA's `gesv_batched` function for CUDA; for CPU, it runs `THLapack(gesv)` in a for loop. The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, a 4-d tensor of size `A x B x M x M` is treated as having batch size `(A x B)`.

  The overhead of creating the magma_queue_t is ~350000 microseconds (~0.35 s) the first time it is called and ~6 microseconds on every call after that.

* Tests and docs
* Address review comments (several rounds)
* Rebase and fix rebase
1 parent f598ef9 commit 7162649

13 files changed: 510 additions, 14 deletions
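The batched behaviour added here can be exercised directly from Python. The sketch below is illustrative only and assumes a PyTorch build of this era, where `torch.gesv(b, A)` solves `A x = b` and returns the solution together with the LU factorization (later releases replaced this interface with `torch.linalg.solve`); it is not part of the diff.

import torch

# Single system: A x = b with A (3 x 3) and b (3 x 1).
A = torch.randn(3, 3) + 3 * torch.eye(3)   # diagonal shift keeps A well conditioned
b = torch.randn(3, 1)
x, lu = torch.gesv(b, A)                   # pre-1.0 argument order: (b, A)
print(torch.dist(A.mm(x), b))              # should be ~0

# Batched systems: all leading dimensions are treated as batch dims.
# A of size (2, 4, 3, 3) is 2 * 4 = 8 independent 3 x 3 systems.
A = torch.randn(2, 4, 3, 3) + 3 * torch.eye(3)
b = torch.randn(2, 4, 3, 5)
x, lu = torch.gesv(b, A)
print(x.shape)                             # torch.Size([2, 4, 3, 5])

# Batch dimensions broadcast: one A is reused for every b in the batch.
A_single = torch.randn(1, 3, 3) + 3 * torch.eye(3)
b_many = torch.randn(2, 2, 3, 5)
x, lu = torch.gesv(b_many, A_single)
print(x.shape)                             # torch.Size([2, 2, 3, 5])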

aten/src/ATen/Declarations.cwrap

Lines changed: 2 additions & 1 deletion
@@ -3004,7 +3004,8 @@
       - THTensor* tensor2
 ]]
 [[
-  name: gesv
+  name: _gesv_single
+  cname: gesv
   types:
     - Float
     - Double

aten/src/ATen/native/Gesv.cpp

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@ (new file)

#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/native/Gesv.h"

#include "TH.h"  // for USE_LAPACK

#include <vector>

#ifdef USE_LAPACK
extern "C" void dgesv_(
    int* n, int* nrhs, double* a, int* lda,
    int* ipiv, double* b, int* ldb, int* info);
extern "C" void sgesv_(
    int* n, int* nrhs, float* a, int* lda,
    int* ipiv, float* b, int* ldb, int* info);
#endif

namespace at { namespace native {

template<class scalar_t>
void lapackGesv(
    int n, int nrhs, scalar_t* a, int lda, int* ipiv,
    scalar_t* b, int ldb, int* info) {
  AT_ERROR("gesv only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackGesv<float>(
    int n, int nrhs, float* a, int lda, int* ipiv,
    float* b, int ldb, int* info) {
  sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackGesv<double>(
    int n, int nrhs, double* a, int lda, int* ipiv,
    double* b, int ldb, int* info) {
  dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}
#endif

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
#ifndef USE_LAPACK
  AT_ERROR("gesv: LAPACK library not found in compilation");
#endif
  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  auto batch_size = batchCount(A);
  auto n = A.size(-2);
  auto nrhs = b.size(-1);

  auto ipiv = b.type().toScalarType(kInt).tensor(n);

  for (int64_t i = 0; i < batch_size; i++) {
    int info;
    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
    lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(),
        b_working_ptr, n, &info);
    infos[i] = info;
    if (info != 0) {
      return;
    }
  }
}

std::tuple<Tensor,Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
  std::vector<int64_t> infos(batchCount(A), 0);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto b_working_copy = cloneBatchedColumnMajor(self);
  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
  });
  checkErrors(infos);
  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
  if (self.dim() <= 2 && A.dim() <= 2) {
    // TODO: #7102: It's not necessary to have gesv (single) bindings for both
    // TH and ATen. We should remove the TH gesv bindings, especially
    // since the lapackGesv function is already in ATen.
    return at::_gesv_single(self, A);
  }

  checkInputs(self, A);

  // broadcast the batch dimensions of self and A.
  IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
  IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
  std::vector<int64_t> expand_batch_portion =
      infer_size(self_batch_sizes, A_batch_sizes);

  std::vector<int64_t> self_expand_size({expand_batch_portion});
  self_expand_size.insert(self_expand_size.end(),
      { self.size(-2), self.size(-1) });

  std::vector<int64_t> A_expand_size({expand_batch_portion});
  A_expand_size.insert(A_expand_size.end(),
      { A.size(-2), A.size(-1) });

  Tensor self_broadcasted = self.expand(self_expand_size);
  Tensor A_broadcasted = A.expand(A_expand_size);
  return self.type()._gesv_helper(self_broadcasted, A_broadcasted);
}

std::tuple<Tensor&,Tensor&> gesv_out(
    Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
  if (self.dim() > 2 || A.dim() > 2) {
    AT_ERROR("torch.gesv() with the `out` keyword does not support batching. "
        "b.dim() (%lld) and A.dim() (%lld) must both be 2.",
        (long long)self.dim(), (long long)A.dim());
  }
  return at::_gesv_single_out(solution, lu, self, A);
}

}} // namespace at::native
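For readers who prefer to follow the control flow outside of ATen, here is a rough Python model of what gesv() above does for batched inputs: broadcast the batch dimensions (as infer_size does), then solve each n x n system independently, the way _gesv_helper_cpu loops over lapackGesv. It is a sketch for intuition only, with hypothetical helper names, not the code path PyTorch actually runs.

import torch

def broadcast_batch_shape(batch_a, batch_b):
    # Right-aligned broadcast of two batch shapes (what infer_size does in gesv()).
    n = max(len(batch_a), len(batch_b))
    a = (1,) * (n - len(batch_a)) + tuple(batch_a)
    b = (1,) * (n - len(batch_b)) + tuple(batch_b)
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:
            raise ValueError("batch dimensions do not broadcast")
        out.append(max(da, db))
    return tuple(out)

def batched_gesv_reference(b, A):
    # Broadcast batch dims, then solve each system one at a time,
    # mirroring the CPU helper's for loop.
    batch = broadcast_batch_shape(b.shape[:-2], A.shape[:-2])
    b = b.expand(batch + tuple(b.shape[-2:])).contiguous()
    A = A.expand(batch + tuple(A.shape[-2:])).contiguous()
    flat_b = b.view(-1, b.size(-2), b.size(-1))
    flat_A = A.view(-1, A.size(-2), A.size(-1))
    xs = [torch.gesv(fb, fA)[0] for fb, fA in zip(flat_b, flat_A)]
    return torch.stack(xs).view(b.shape)

A = torch.randn(2, 2, 4, 4) + 4 * torch.eye(4)
b = torch.randn(1, 2, 4, 3)
x_ref = batched_gesv_reference(b, A)
x, _ = torch.gesv(b, A)
print((x - x_ref).abs().max())   # should be ~0, up to round-off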

aten/src/ATen/native/Gesv.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@ (new file)

#include "ATen/ATen.h"

namespace at { namespace native {

static inline void checkInputs(const Tensor& self, const Tensor& A) {
  if (A.size(-1) != A.size(-2)) {
    AT_ERROR("A must be batches of square matrices, "
        "but they are %lld by %lld matrices",
        (long long)A.size(-1), (long long)A.size(-2));
  }
  if (A.size(-1) != self.size(-2)) {
    AT_ERROR("Incompatible matrix sizes for matmul: each A "
        "matrix is %llu by %lld but each b matrix is %lld by %lld.",
        (long long)A.size(-1), (long long)A.size(-1),
        (long long)self.size(-2), (long long)self.size(-1));
  }
}

static inline void checkErrors(std::vector<int64_t> infos) {
  for (size_t i = 0; i < infos.size(); i++) {
    auto info = infos[i];
    if (info < 0) {
      AT_ERROR("gesv: For batch %lld: Argument %lld has illegal value",
          (long long)i, -info);
    } else if (info > 0) {
      AT_ERROR("gesv: For batch %lld: U(%lld,%lld) is zero, singular U.",
          (long long)i, info, info);
    }
  }
}

}} // namespace at::native
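checkErrors follows the usual LAPACK convention: a negative info flags an invalid argument, and a positive info means the factor U has an exact zero on its diagonal, i.e. that batch element's matrix is singular. The second case can be provoked from Python roughly as below (a sketch only; the exact error text may vary between builds):

import torch

A = torch.eye(3).repeat(4, 1, 1)   # batch of 4 identity matrices
A[2].zero_()                       # make batch element 2 singular
b = torch.randn(4, 3, 1)
try:
    torch.gesv(b, A)
except RuntimeError as err:
    print(err)   # expected to mention batch 2 and a singular U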
aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@ (new file)

#include "ATen/ATen.h"

namespace at { namespace native {

/*
 * Clones a Tensor so that the following conditions hold:
 * If we think of a Tensor as having size (B, M, N), where B is any number
 * of batch dimensions, then:
 * - Each (M, N) matrix is in column major form
 * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
 *   Then when laid out in memory, the M by N matrix starting at
 *   P.data_ptr()[b * M * N] is of the same corresponding batch as the M' by N'
 *   matrix starting at Q.data_ptr()[b * M' * N'].
 */
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  // If src is already in batched column major format, then
  // this will be efficient (no reordering of the data will occur)
  // because the first transpose will make the tensor contiguous,
  // and cloning a contiguous tensor is fast.
  auto result = src.transpose(-2, -1).clone();
  result.transpose_(-2, -1);
  return result;
}

/*
 * Given batches of matrices with arbitrary batch dim,
 * computes the number of batches.
 */
static inline int64_t batchCount(const Tensor& batched_matrices) {
  int64_t result = 1;
  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
    result *= batched_matrices.size(i);
  }
  return result;
}

// Computes the number of elements of a matrix in a batched matrix tensor
static inline int64_t matrixStride(const Tensor& batched_matrices) {
  return batched_matrices.size(-1) * batched_matrices.size(-2);
}

}} // namespace at::native
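These helpers have straightforward Python analogues, shown below as a sketch to make the memory layout concrete: the transpose-clone-transpose trick produces a tensor whose individual matrices are column major (Fortran order), which is what LAPACK and MAGMA expect. The helper names are illustrative, not part of the diff.

import torch

def clone_batched_column_major(src):
    # Same trick as cloneBatchedColumnMajor: transpose, clone (which makes the
    # transposed view contiguous), then transpose back.
    result = src.transpose(-2, -1).clone()
    return result.transpose(-2, -1)

def batch_count(t):
    count = 1
    for d in t.shape[:-2]:
        count *= d
    return count

def matrix_stride(t):
    return t.size(-1) * t.size(-2)

x = torch.arange(2 * 3 * 4).float().view(2, 3, 4)
y = clone_batched_column_major(x)
print(batch_count(y), matrix_stride(y))   # 2 12
print(y.stride())   # e.g. (12, 1, 3): within each matrix, columns are contiguous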

aten/src/ATen/native/cuda/Gesv.cu

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@ (new file)

#include "ATen/Context.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
#include "ATen/PinnedMemoryAllocator.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"

#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/native/Gesv.h"

#include "THC.h"  // for USE_MAGMA

#ifdef USE_MAGMA
#include <magma.h>
#include <magma_types.h>
#endif

namespace at {
namespace native {

#ifdef USE_MAGMA
template<class scalar_t>
void magmaGesvBatched(
    magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  AT_ERROR("gesv only takes float or double Tensors");
}

template<>
void magmaGesvBatched<float>(
    magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_sgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}

template<>
void magmaGesvBatched<double>(
    magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_dgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}

static magma_queue_t createMagmaQueue(const Tensor& tensor) {
  auto& context = tensor.type().get_context();
  magma_queue_t magma_queue;
  magma_queue_create_from_cuda(
      tensor.get_device(),
      context.getCurrentCUDAStream(),
      THCState_getCurrentBlasHandle(context.thc_state),
      THCState_getCurrentSparseHandle(context.thc_state),
      &magma_queue);
  return magma_queue;
}
#endif

static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  if (static_cast<int64_t>(result) != value) {
    AT_ERROR("magma: The value of %s (%lld) is too large to fit into a magma_int_t (%llu bytes)",
        varname, (long long)value, sizeof(magma_int_t));
  }
  return result;
}

// Creates an array of size elements of type T, backed by pinned memory
// wrapped in a Storage
template<class T>
static inline std::unique_ptr<Storage> pin_memory(int64_t size, Tensor dummy) {
  int64_t adjusted_size = size * sizeof(T);
  auto allocator = std::unique_ptr<Allocator>(new PinnedMemoryAllocator());
  auto& backend = dummy.type().toBackend(kCPU).toScalarType(kByte);
  return backend.storageWithAllocator(adjusted_size, std::move(allocator));
}

#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
  auto storage_##name = pin_memory<type>(size, dummy_tensor); \
  name = reinterpret_cast<type*>(storage_##name->data());

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
#ifndef USE_MAGMA
  AT_ERROR("gesv: MAGMA library not found in "
      "compilation. Please rebuild with MAGMA.");
#else
  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

  magma_int_t* info_array;
  magma_int_t* ipiv_data;
  magma_int_t** ipiv_array;
  scalar_t** A_array;
  scalar_t** b_array;

  ALLOCATE_ARRAY(info_array, magma_int_t, batch_size, b);
  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);

  // Set up the created arrays
  for (int64_t i = 0; i < batch_size; i++) {
    A_array[i] = &A_data[i * A_mat_stride];
    b_array[i] = &b_data[i * b_mat_stride];
    ipiv_array[i] = &ipiv_data[i * n];
  }

  magmaGesvBatched<scalar_t>(
      n, nrhs, A_array, n, ipiv_array, b_array, n,
      info_array, batch_size, createMagmaQueue(b));

  for (int64_t i = 0; i < batch_size; i++) {
    infos[i] = info_array[i];
  }
#endif
}

std::tuple<Tensor,Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
  std::vector<int64_t> infos(batchCount(A), 0);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto b_working_copy = cloneBatchedColumnMajor(self);
  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
  });
  checkErrors(infos);
  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

}} // namespace at::native

#undef ALLOCATE_ARRAY
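A simple way to sanity-check the MAGMA path against the CPU path is to solve the same batch on both devices and compare. The sketch below assumes a CUDA build of PyTorch compiled with MAGMA, and uses a diagonal shift to keep the matrices well conditioned.

import torch

if torch.cuda.is_available():
    A = torch.randn(8, 5, 5) + 5 * torch.eye(5)
    b = torch.randn(8, 5, 3)
    x_cpu, _ = torch.gesv(b, A)                  # CPU: LAPACK gesv in a loop
    x_gpu, _ = torch.gesv(b.cuda(), A.cuda())    # CUDA: magma_*gesv_batched
    # The two backends should agree up to floating point round-off.
    print((x_cpu - x_gpu.cpu()).abs().max())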

aten/src/ATen/native/native_functions.yaml

Lines changed: 11 additions & 0 deletions
@@ -494,6 +494,17 @@
 - func: ger_out(Tensor result, Tensor self, Tensor vec2) -> Tensor
   variants: function
 
+- func: gesv(Tensor self, Tensor A) -> (Tensor, Tensor)
+
+- func: gesv_out(Tensor solution, Tensor lu, Tensor self, Tensor A) -> (Tensor, Tensor)
+  variants: function
+
+# gesv handles broadcasting of arbitrary batch dims while _gesv_helper does not.
+- func: _gesv_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
+  dispatch:
+    CPU: _gesv_helper_cpu
+    CUDA: _gesv_helper_cuda
+
 - func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enabled=True) -> Tensor
   variants: function

test/test_autograd.py

Lines changed: 4 additions & 0 deletions
@@ -2869,6 +2869,10 @@ class dont_convert(tuple):
     ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
      'large', NO_ARGS, [skipIfNoLapack]),
     ('gesv', (S, S), ((S, S),), '', NO_ARGS, [skipIfNoLapack]),
+    ('gesv', (S, S, S), ((S, S, S),), 'batched', NO_ARGS, [skipIfNoLapack]),
+    ('gesv', (2, 3, S, S), ((2, 3, S, S),), 'batched_dims', NO_ARGS, [skipIfNoLapack]),
+    ('gesv', (2, 2, S, S), ((1, S, S),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
+    ('gesv', (1, S, S), ((2, 2, S, S),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
     ('fill_', (S, S, S), (1,), 'number'),
     ('fill_', (), (1,), 'number_scalar'),
     # FIXME: we should compute the derivative w.r.t torch.tensor(1)
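The entries above feed batched gesv through the standard autograd method-test machinery. A stand-alone check looks roughly like the following sketch; double precision and a diagonal shift are used so that gradcheck's finite differences stay stable.

import torch
from torch.autograd import gradcheck

A = (torch.randn(2, 3, 4, 4).double() + 4 * torch.eye(4).double()).requires_grad_()
b = torch.randn(2, 3, 4, 2).double().requires_grad_()
# Numerically checks d(solution)/db and d(solution)/dA against autograd.
print(gradcheck(lambda b, A: torch.gesv(b, A)[0], (b, A)))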

test/test_cuda.py

Lines changed: 8 additions & 0 deletions
@@ -1320,6 +1320,14 @@ def _select_broadcastable_dims(dims_full=None):
     def test_det_logdet_slogdet(self):
         TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())
 
+    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+    def test_gesv_batched(self):
+        TestTorch._test_gesv_batched(self, lambda t: t.cuda())
+
+    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+    def test_gesv_batched_dims(self):
+        TestTorch._test_gesv_batched_dims(self, lambda t: t.cuda())
+
     def test_view(self):
         TestTorch._test_view(self, lambda t: t.cuda())
