Commit 8c4b2a8

vishwakftw authored and facebook-github-bot committed
Remove extra workspace queries in matrix inverse computation (#20904)
Summary:
Earlier, the workspace size query and allocation were placed inside the loop. However, since we operate on a batch of matrices that all have the same number of rows and columns, repeating the workspace size query and allocation for every matrix in the batch is redundant. This PR moves the workspace size query and allocation outside the loop, saving (batch_size - 1) queries and allocations (and the corresponding deallocations). This yields a tremendous speedup in inverse computation.

Changelog:
- Move workspace query and allocation outside the batch loop

Pull Request resolved: #20904

Differential Revision: D15495505

Pulled By: ezyang

fbshipit-source-id: 226729734465fcaf896f86e1b1a548a81440e082
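For illustration, the hoisted workspace-query pattern looks roughly like the standalone sketch below. It calls the raw Fortran LAPACK routines dgetrf_/dgetri_ directly instead of ATen's lapackLu/lapackGetri wrappers, and the flat column-major batch layout and the batched_inverse helper are assumptions made for this example, not code from the patch.

// Standalone sketch of the pattern, assuming a Fortran LAPACK library is linked (e.g. -llapack).
// The dgetrf_/dgetri_ declarations and the batched_inverse helper are illustrative only.
#include <vector>

extern "C" {
void dgetrf_(int* m, int* n, double* a, int* lda, int* ipiv, int* info);
void dgetri_(int* n, double* a, int* lda, int* ipiv, double* work, int* lwork, int* info);
}

// Inverts batch_size n x n column-major matrices stored contiguously in `a`, in place.
void batched_inverse(double* a, int n, int batch_size) {
  std::vector<int> ipiv(n);
  int info = 0;

  // Workspace query done once: every matrix in the batch has the same n, so the
  // optimal lwork is identical for all of them. lwork = -1 asks dgetri_ to report
  // the optimal size in wkopt without computing anything.
  int lwork = -1;
  double wkopt = 0.0;
  dgetri_(&n, a, &n, ipiv.data(), &wkopt, &lwork, &info);
  lwork = static_cast<int>(wkopt);
  std::vector<double> work(lwork);  // one allocation, reused for every matrix

  for (int i = 0; i < batch_size; ++i) {
    double* a_working = a + static_cast<long long>(i) * n * n;
    dgetrf_(&n, &n, a_working, &n, ipiv.data(), &info);  // LU factorization
    if (info != 0) return;
    dgetri_(&n, a_working, &n, ipiv.data(), work.data(), &lwork, &info);  // inverse from the LU factors
    if (info != 0) return;
  }
}

As in the patch, the only per-matrix work left inside the loop is the factorization and the inversion itself.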
1 parent 4109ec1 commit 8c4b2a8

File tree

1 file changed: +17 -14 lines changed


aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 17 additions & 14 deletions
@@ -145,12 +145,13 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
   auto nrhs = b.size(-1);
 
   auto ipiv = at::empty({n}, b.options().dtype(kInt));
+  auto ipiv_data = ipiv.data<int>();
 
   int info;
   for (int64_t i = 0; i < batch_size; i++) {
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-    lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
+    lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info);
     infos[i] = info;
     if (info != 0) {
       return;
@@ -206,28 +207,30 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
   auto n = self.size(-2);
 
   auto ipiv = at::empty({n}, self.options().dtype(kInt));
-  int lwork;
-  scalar_t wkopt;
-  Tensor work;
+  auto ipiv_data = ipiv.data<int>();
 
   int info;
+  // Run once, first to get the optimum work size
+  // Since we deal with batches of matrices with the same dimensions, doing this outside
+  // the loop saves (batch_size - 1) workspace queries which would provide the same result
+  // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
+  int lwork = -1;
+  scalar_t wkopt;
+  lapackGetri<scalar_t>(n, self_data, n, ipiv_data, &wkopt, lwork, &info);
+  lwork = static_cast<int>(wkopt);
+  Tensor work = at::empty({lwork}, self.options());
+  auto work_data = work.data<scalar_t>();
+
   for (int64_t i = 0; i < batch_size; i++) {
     scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
-    lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
+    lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv_data, &info);
     infos[i] = info;
     if (info != 0) {
       return;
     }
 
-    // Run twice, first to get the optimum work size
-    lwork = -1;
-    lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), &wkopt, lwork, &info);
-
-    lwork = static_cast<int>(wkopt);
-    work = at::empty({lwork}, self.options());
-
-    // now to compute the actual inverse
-    lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), work.data<scalar_t>(), lwork, &info);
+    // now compute the actual inverse
+    lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv_data, work_data, lwork, &info);
     infos[i] = info;
     if (info != 0) {
       return;
