Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,13 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
auto nrhs = b.size(-1);

auto ipiv = at::empty({n}, b.options().dtype(kInt));
auto ipiv_data = ipiv.data<int>();

int info;
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info);
infos[i] = info;
if (info != 0) {
return;
Expand Down Expand Up @@ -206,28 +207,30 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
auto n = self.size(-2);

auto ipiv = at::empty({n}, self.options().dtype(kInt));
int lwork;
scalar_t wkopt;
Tensor work;
auto ipiv_data = ipiv.data<int>();

int info;
// Query the optimal workspace size once, before the loop.
// Since every matrix in the batch has the same dimensions, doing this outside
// the loop saves (batch_size - 1) workspace queries, which would all return the same result,
// and (batch_size - 1) calls to allocate and deallocate the workspace using at::empty().
int lwork = -1;
scalar_t wkopt;
lapackGetri<scalar_t>(n, self_data, n, ipiv_data, &wkopt, lwork, &info);
lwork = static_cast<int>(wkopt);
Tensor work = at::empty({lwork}, self.options());
auto work_data = work.data<scalar_t>();

for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv_data, &info);
infos[i] = info;
if (info != 0) {
return;
}

// Run twice, first to get the optimum work size
lwork = -1;
lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), &wkopt, lwork, &info);

lwork = static_cast<int>(wkopt);
work = at::empty({lwork}, self.options());

// now to compute the actual inverse
lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), work.data<scalar_t>(), lwork, &info);
// now compute the actual inverse
lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv_data, work_data, lwork, &info);
infos[i] = info;
if (info != 0) {
return;
Expand Down