
Commit 7b2fb01

vishwakftw authored and facebook-github-bot committed

Make potrs batched (#13453)
Summary:
This is a straightforward PR, building on the batch inverse PR, except for one change: the GENERATE_LINALG_HELPER_n_ARGS macro has been removed, since it is not very general and the resulting code is actually not very copy-pasty.

Billing of changes:
- Add batching for `potrs`
- Add relevant tests
- Modify doc string

Minor changes:
- Remove `_gesv_single`, `_getri_single` from `aten_interned_strings.h`
- Add test for CUDA `potrs` (2D Tensor op)
- Move the batched shape checking to `LinearAlgebraUtils.h`

Pull Request resolved: #13453
Reviewed By: soumith
Differential Revision: D12942039
Pulled By: zou3519
fbshipit-source-id: 1b8007f00218e61593fc415865b51c1dac0b6a35
1 parent e3e6ca1 commit 7b2fb01
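
For context, a minimal usage sketch of the batched solve this commit enables. This is a hedged example, not part of the diff: it assumes the existing Python binding torch.potrs(b, u, upper=...) and a build that includes this change; shapes are illustrative.

  import torch

  # Batch of 4 symmetric positive-definite matrices and matching right-hand sides.
  A = torch.randn(4, 3, 3)
  A = A.matmul(A.transpose(-2, -1)) + 1e-3 * torch.eye(3)
  b = torch.randn(4, 3, 2)

  # potrf is not batched at this point, so factor each matrix individually.
  u = torch.stack([torch.potrf(A[i], True) for i in range(4)])

  x = torch.potrs(b, u, True)  # batched potrs: solves A[i] @ x[i] = b[i] for every i
  print(x.shape)               # torch.Size([4, 3, 2])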

File tree: 13 files changed, +358 −74 lines


aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 1 deletion
@@ -2671,7 +2671,7 @@
       default: U
 ]]
 [[
-  name: _th_potrs
+  name: _th_potrs_single
   cname: potrs
   types:
     - Float

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 1 addition & 2 deletions
@@ -81,8 +81,6 @@ _(aten, _floor) \
 _(aten, _fused_dropout) \
 _(aten, _ger) \
 _(aten, _gesv_helper) \
-_(aten, _gesv_single) \
-_(aten, _getri_single) \
 _(aten, _indexCopy) \
 _(aten, _indices) \
 _(aten, _inverse_helper) \
@@ -103,6 +101,7 @@ _(aten, _pack_padded_sequence_backward) \
 _(aten, _pad_packed_sequence) \
 _(aten, _pdist_backward) \
 _(aten, _pdist_forward) \
+_(aten, _potrs_helper) \
 _(aten, _prod) \
 _(aten, _prodall) \
 _(aten, _range) \

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 103 additions & 25 deletions
@@ -1,7 +1,6 @@
 #include "ATen/ATen.h"
 #include "ATen/CPUApplyUtils.h"
 #include "ATen/Dispatch.h"
-#include "ATen/ExpandUtils.h"
 #include "ATen/NativeFunctions.h"
 
 #include "ATen/native/LinearAlgebraUtils.h"
@@ -16,14 +15,18 @@
 #ifdef USE_LAPACK
 
 // gesv
-extern "C" void dgesv_(int* n, int* nrhs, double* a, int* lda, int *ipiv, double* b, int* ldb, int* info);
-extern "C" void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b, int* ldb, int* info);
+extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
+extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
 
 // inverse
 extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
 extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
 extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
+
+// potrs
+extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
+extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 #endif
 
 namespace at {
@@ -32,12 +35,12 @@ namespace native {
 // Define the per-batch functions to be used in the main implementation of the batched
 // linear algebra operations
 template<class scalar_t>
-void lapackGesv(int n, int nrhs, scalar_t* a, int lda, int* ipiv, scalar_t* b, int ldb, int* info) {
+void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
   AT_ERROR("gesv only takes float or double Tensors");
 }
 
 template<class scalar_t>
-void lapackGetrf(int m, int n, scalar_t* a, int lda, int *ipiv, int *info) {
+void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
   AT_ERROR("getrf only takes float or double Tensors");
 }
 
@@ -46,12 +49,17 @@ void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwo
   AT_ERROR("getri only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("potrs only takes float or double Tensors");
+}
+
 #ifdef USE_LAPACK
-template<> void lapackGesv<double>(int n, int nrhs, double* a, int lda, int* ipiv, double* b, int ldb, int* info) {
+template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
-template<> void lapackGesv<float>(int n, int nrhs, float* a, int lda, int* ipiv, float* b, int ldb, int* info) {
+template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
   sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
@@ -70,6 +78,14 @@ template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv,
 template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
   sgetrf_(&m, &n, a, &lda, ipiv, info);
 }
+
+template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+  dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+  spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
 #endif
 
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
@@ -105,8 +121,16 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
   }
 }
 
-// These utilities are specified in LinearAlgebraUtils.h
-GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cpu)
+std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  });
+  batchCheckErrors(infos, "gesv");
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}
 
 // Supports arbitrary batch dimensions for self and A
 std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
@@ -117,21 +141,8 @@ std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
     return at::_th_gesv_single(self, A);
   }
 
-  gesvCheckInputs(self, A);
-
-  // broadcast the batch dimensions of self and A.
-  IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
-  IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
-  std::vector<int64_t> expand_batch_portion = infer_size(self_batch_sizes, A_batch_sizes);
-
-  std::vector<int64_t> self_expand_size({expand_batch_portion});
-  self_expand_size.insert(self_expand_size.end(), { self.size(-2), self.size(-1) });
-
-  std::vector<int64_t> A_expand_size({expand_batch_portion});
-  A_expand_size.insert(A_expand_size.end(), { A.size(-2), A.size(-1) });
-
-  Tensor self_broadcasted = self.expand(self_expand_size);
-  Tensor A_broadcasted = A.expand(A_expand_size);
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
   return at::_gesv_helper(self_broadcasted, A_broadcasted);
 }
 
@@ -185,7 +196,15 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
   }
 }
 
-GENERATE_LINALG_HELPER_1_ARGS(inverse, self, cpu)
+Tensor _inverse_helper_cpu(const Tensor& self) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
+    apply_inverse<scalar_t>(self_working_copy, infos);
+  });
+  batchCheckErrors(infos, "inverse");
+  return self_working_copy;
+}
 
 Tensor inverse(const Tensor &self) {
   if (self.size(-1) == 0) {
@@ -206,4 +225,63 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<typename scalar_t>
+static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("potrs: LAPACK library not found in compilation");
+#endif
+  char uplo = upper ? 'U' : 'L';
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto A_mat_stride = matrixStride(A);
+  auto b_mat_stride = matrixStride(b);
+
+  auto batch_size = batchCount(A);
+  auto n = A.size(-2);
+  auto nrhs = b.size(-1);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    int info;
+    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+    lapackPotrs<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+    infos[i] = info;
+    if (info != 0) {
+      return;
+    }
+  }
+}
+
+Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
+    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, infos);
+  });
+  batchCheckErrors(infos, "potrs");
+  return self_working_copy;
+}
+
+// Supports arbitrary batch dimensions for self and A
+Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
+  if (self.dim() <= 2 && A.dim() <= 2) {
+    return at::_th_potrs_single(self, A, upper);
+  }
+
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
+  return at::_potrs_helper(self_broadcasted, A_broadcasted, upper);
+}
+
+Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
+  AT_CHECK(self.dim() == 2 && A.dim() == 2,
+           "torch.potrs() with the `out` keyword does not support batching. "
+           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+  return at::_th_potrs_single_out(result, self, A, upper);
+}
+
 }} // namespace at::native

aten/src/ATen/native/LegacyDefinitions.cpp

Lines changed: 0 additions & 8 deletions
@@ -491,14 +491,6 @@ Tensor cholesky(const Tensor & self, bool upper) {
   return at::_th_potrf(self, upper);
 }
 
-Tensor & potrs_out(Tensor & result, const Tensor & self, const Tensor & input2, bool upper) {
-  return at::_th_potrs_out(result, self, input2, upper);
-}
-
-Tensor potrs(const Tensor & self, const Tensor & input2, bool upper) {
-  return at::_th_potrs(self, input2, upper);
-}
-
 Tensor & potri_out(Tensor & result, const Tensor & self, bool upper) {
   return at::_th_potri_out(result, self, upper);
 }

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 22 additions & 25 deletions
@@ -1,4 +1,5 @@
 #include "ATen/ATen.h"
+#include "ATen/ExpandUtils.h"
 #include <limits>
 
 namespace at { namespace native {
@@ -52,8 +53,8 @@ static inline double _get_epsilon(const ScalarType& sc_type) {
   }
 }
 
-// Validates input shapes for gesv
-static inline void gesvCheckInputs(const Tensor& self, const Tensor& A) {
+// Validates input shapes for linear solve methods (gesv, potrs)
+static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
   AT_CHECK(A.size(-1) == A.size(-2),
            "A must be batches of square matrices, "
            "but they are ", A.size(-1), " by ", A.size(-2), " matrices");
@@ -87,34 +88,30 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* nam
   }
 }
 
-#define GENERATE_LINALG_HELPER_1_ARGS(NAME, ARG, BACKEND) \
-Tensor _##NAME##_helper_##BACKEND(const Tensor& ARG) { \
-  std::vector<int64_t> infos(batchCount(ARG), 0); \
-  auto ARG##_working_copy = cloneBatchedColumnMajor(ARG); \
-  AT_DISPATCH_FLOATING_TYPES(ARG.type(), #NAME, [&]{ \
-    apply_##NAME<scalar_t>(ARG##_working_copy, infos); \
-  }); \
-  batchCheckErrors(infos, #NAME); \
-  return ARG##_working_copy; \
-}
-
-#define GENERATE_LINALG_HELPER_2_ARGS(NAME, ARG1, ARG2, BACKEND) \
-std::tuple<Tensor, Tensor> _##NAME##_helper_##BACKEND(const Tensor& ARG1, const Tensor& ARG2) { \
-  std::vector<int64_t> infos(batchCount(ARG1), 0); \
-  auto ARG1##_working_copy = cloneBatchedColumnMajor(ARG1); \
-  auto ARG2##_working_copy = cloneBatchedColumnMajor(ARG2); \
-  AT_DISPATCH_FLOATING_TYPES(ARG1.type(), #NAME, [&]{ \
-    apply_##NAME<scalar_t>(ARG1##_working_copy, ARG2##_working_copy, infos); \
-  }); \
-  batchCheckErrors(infos, #NAME); \
-  return std::tuple<Tensor, Tensor>(ARG1##_working_copy, ARG2##_working_copy); \
-}
-
 // Checks if all the Tensors in a TensorList are of the same dimensions
 static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
   for (auto &t : tensors) {
     AT_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
   }
 }
 
+static inline std::tuple<Tensor,Tensor> _linear_solve_broadcast_args(const Tensor& arg1, const Tensor& arg2) {
+  linearSolveCheckInputs(arg1, arg2);
+
+  // broadcast the batch dimensions of arg1 and arg2.
+  IntList arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
+  IntList arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
+
+  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
+  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
+
+  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
+  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
+
+  Tensor arg1_broadcasted = arg1.expand(arg1_expand_size);
+  Tensor arg2_broadcasted = arg2.expand(arg2_expand_size);
+  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
+}
+
 }} // namespace at::native
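
_linear_solve_broadcast_args applies ordinary broadcasting to the batch dimensions only (via infer_size) and then re-appends each argument's own matrix dimensions. A small shape-only sketch of the expected behaviour, using NumPy purely as a shape calculator; the shapes are made up for illustration:

  import numpy as np

  b_shape = (2, 1, 3, 4)   # batch dims (2, 1), matrix dims (3, 4)
  A_shape = (5, 3, 3)      # batch dims (5,),  matrix dims (3, 3)

  # Broadcast only the leading (batch) dimensions, then re-append each matrix shape.
  batch = np.broadcast_shapes(b_shape[:-2], A_shape[:-2])  # (2, 5)
  print(batch + b_shape[-2:])  # (2, 5, 3, 4): expanded shape of b
  print(batch + A_shape[-2:])  # (2, 5, 3, 3): expanded shape of A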
