54 commits
0096374  Add batched inverse (vishwakftw, Jul 27, 2018)
157280b  Remove pin_memory from Gesv.cu (vishwakftw, Jul 27, 2018)
bd43e52  Missed #endif (vishwakftw, Jul 27, 2018)
6ed2f09  Fix computation of inverses (vishwakftw, Jul 27, 2018)
0d5a133  Silly typographical error; my bad (vishwakftw, Jul 27, 2018)
47ffc82  Modify docs, derivative, and add test in test_autograd (vishwakftw, Jul 28, 2018)
69db465  Add tests in test_cuda, test_torch (vishwakftw, Jul 28, 2018)
bf165f1  Remove _batch_inverse from MVN (vishwakftw, Jul 28, 2018)
24b5d97  Fix include error (vishwakftw, Jul 28, 2018)
1dfabaf  Rename files for consistency (vishwakftw, Jul 28, 2018)
9a22157  Clean up error messages, pass infos by ref (vishwakftw, Jul 28, 2018)
5fa97b3  Modify test case for full rank matrices (vishwakftw, Aug 2, 2018)
cc6b9e4  Some changes (vishwakftw, Aug 4, 2018)
50ef95d  Merge branch 'master' into batch-inverse (vishwakftw, Aug 4, 2018)
5c78d86  exclude gesv and inverse functions from test_jit (vishwakftw, Aug 4, 2018)
53fb608  fix nit (vishwakftw, Aug 4, 2018)
752763f  Remove unnecessary Python bindings, fix inverse for empty tensors (vishwakftw, Aug 4, 2018)
fba711b  Merge branch 'master' into batch-inverse (vishwakftw, Aug 21, 2018)
b4ff55f  fix CUDA build failure (vishwakftw, Aug 22, 2018)
5f94114  Merge branch 'master' into batch-inverse (vishwakftw, Aug 28, 2018)
2c86c31  Update MiscUtils.h (vishwakftw, Aug 28, 2018)
a5b9fd8  Fix lint (vishwakftw, Aug 28, 2018)
639f938  Fix CUDA build failure (vishwakftw, Aug 29, 2018)
dcf99d0  Merge branch 'master' into batch-inverse (vishwakftw, Sep 11, 2018)
241ae6b  Fix some nits (vishwakftw, Sep 11, 2018)
61802a4  Merge branch 'batch-inverse' of github.com:vishwakftw/pytorch into ba… (vishwakftw, Sep 11, 2018)
5866244  Merge branch 'master' into batch-inverse (vishwakftw, Sep 11, 2018)
c293f40  Merge branch 'master' into batch-inverse (vishwakftw, Sep 12, 2018)
d383d9f  Remove check after batched getrf (vishwakftw, Sep 12, 2018)
4c8cb4a  self.type()._inverse_helper(self) --> at::_inverse_helper(self) (vishwakftw, Sep 12, 2018)
62102ff  Merge branch 'master' of https://github.com/pytorch/pytorch into batc… (vishwakftw, Sep 15, 2018)
574aaf9  Refactor batch linear operations (vishwakftw, Sep 15, 2018)
c188de4  Revert submodule change, and revert changes to test_jit (vishwakftw, Sep 15, 2018)
b1e44fb  Reorganize functionally (vishwakftw, Sep 15, 2018)
723e871  Merge branch 'master' of https://github.com/pytorch/pytorch into batc… (vishwakftw, Oct 10, 2018)
c0cf1d8  Fix one more conflict (vishwakftw, Oct 10, 2018)
dfc5d30  add _inverse_helper to aten_interned_strings.h (vishwakftw, Oct 10, 2018)
09671b0  Fix test error (vishwakftw, Oct 10, 2018)
a217271  Fix one more test (vishwakftw, Oct 10, 2018)
d7d3f1f  Merge branch 'master' into batch-inverse (vishwakftw, Oct 17, 2018)
a0e15f3  CR (vishwakftw, Oct 17, 2018)
e31f74f  Fix nits (vishwakftw, Oct 18, 2018)
3cb4485  Fix more nits (vishwakftw, Oct 18, 2018)
2611bed  Fix tests and CR (vishwakftw, Oct 19, 2018)
dc2e9ae  Merge branch 'master' into batch-inverse (vishwakftw, Oct 19, 2018)
981a712  Merge branch 'batch-inverse' of github.com:vishwakftw/pytorch into ba… (vishwakftw, Oct 19, 2018)
9e5acd7  Fix tests, build (vishwakftw, Oct 19, 2018)
cba5237  Undo submodule change, modify docs for matrix_power (vishwakftw, Oct 19, 2018)
f675f70  Merge branch 'master' of https://github.com/pytorch/pytorch into batc… (vishwakftw, Oct 21, 2018)
221209d  CR and improvements (vishwakftw, Oct 22, 2018)
ca96075  Add RAII struct for MAGMA queue management, revert getrf optimization (vishwakftw, Oct 23, 2018)
3dfae69  CR (vishwakftw, Oct 24, 2018)
7fb2196  Merge branch 'master' of https://github.com/pytorch/pytorch into batc… (vishwakftw, Oct 24, 2018)
8cc6504  Delete default constructor for MAGMAQueue (vishwakftw, Oct 24, 2018)
2 changes: 1 addition & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -2758,7 +2758,7 @@
default: S
]]
[[
-name: _getri
+name: _getri_single
cname: getri
types:
- Float
3 changes: 2 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -82,9 +82,10 @@ _(aten, _fused_dropout) \
_(aten, _ger) \
_(aten, _gesv_helper) \
_(aten, _gesv_single) \
-_(aten, _getri) \
+_(aten, _getri_single) \
_(aten, _indexCopy) \
_(aten, _indices) \
+_(aten, _inverse_helper) \
_(aten, _linspace) \
_(aten, _local_scalar) \
_(aten, _local_scalar_dense) \
209 changes: 209 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -0,0 +1,209 @@
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/native/LinearAlgebraUtils.h"

#include "TH.h" // for USE_LAPACK

#include <vector>

// First, the required LAPACK routine declarations are registered here.
// A comment above each registered LAPACK routine indicates which batched
// linear algebra function uses that routine.
#ifdef USE_LAPACK

// gesv
extern "C" void dgesv_(int* n, int* nrhs, double* a, int* lda, int *ipiv, double* b, int* ldb, int* info);
extern "C" void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b, int* ldb, int* info);

// inverse
extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
#endif

namespace at {
namespace native {

// Define the per-batch functions to be used in the main implementation of the batched
// linear algebra operations
template<class scalar_t>
void lapackGesv(int n, int nrhs, scalar_t* a, int lda, int* ipiv, scalar_t* b, int ldb, int* info) {
AT_ERROR("gesv only takes float or double Tensors");
}

template<class scalar_t>
void lapackGetrf(int m, int n, scalar_t* a, int lda, int *ipiv, int *info) {
AT_ERROR("getrf only takes float or double Tensors");
}

template<class scalar_t>
void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info) {
AT_ERROR("getri only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackGesv<double>(int n, int nrhs, double* a, int lda, int* ipiv, double* b, int ldb, int* info) {
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackGesv<float>(int n, int nrhs, float* a, int lda, int* ipiv, float* b, int ldb, int* info) {
sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackGetri<double>(int n, double *a, int lda, int *ipiv, double *work, int lwork, int *info) {
dgetri_(&n, a, &lda, ipiv, work, &lwork, info);
}

template<> void lapackGetri<float>(int n, float *a, int lda, int *ipiv, float *work, int lwork, int *info) {
sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
}

template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
dgetrf_(&m, &n, a, &lda, ipiv, info);
}

template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
sgetrf_(&m, &n, a, &lda, ipiv, info);
}
#endif

// Below are the definitions of the per-batch functions that are dispatched
// by the main helper functions for the linear algebra operations.

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("gesv: LAPACK library not found in compilation");
#endif
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);

auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);

auto ipiv = at::empty({n}, b.type().toScalarType(kInt));
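// (ipiv receives the LU pivot indices from each gesv call; the buffer is
// reused for every matrix in the batch.)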

for (int64_t i = 0; i < batch_size; i++) {
int info;
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
infos[i] = info;
if (info != 0) {
return;
}
}
}

// These utilities are specified in LinearAlgebraUtils.h
GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cpu)

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
if (self.dim() <= 2 && A.dim() <= 2) {
// TODO: #7102: It's not necessary to have gesv (single) bindings for both
// TH and ATen. We should remove the TH gesv bindings, especially
// since the lapackGesv function is already in ATen.
return at::_gesv_single(self, A);
}

gesvCheckInputs(self, A);

// broadcast the batch dimensions of self and A.
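// e.g. self with batch dims [1, 5] and A with batch dims [2, 1] broadcast
// to a common batch shape [2, 5].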
IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
std::vector<int64_t> expand_batch_portion = infer_size(self_batch_sizes, A_batch_sizes);

std::vector<int64_t> self_expand_size({expand_batch_portion});
self_expand_size.insert(self_expand_size.end(), { self.size(-2), self.size(-1) });

std::vector<int64_t> A_expand_size({expand_batch_portion});
A_expand_size.insert(A_expand_size.end(), { A.size(-2), A.size(-1) });

Tensor self_broadcasted = self.expand(self_expand_size);
Tensor A_broadcasted = A.expand(A_expand_size);
return at::_gesv_helper(self_broadcasted, A_broadcasted);
}

std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.gesv() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
return at::_gesv_single_out(solution, lu, self, A);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("inverse: LAPACK library not found in compilation");
#endif
auto self_data = self.data<scalar_t>();
auto self_matrix_stride = matrixStride(self);

auto batch_size = batchCount(self);
auto n = self.size(-2);

auto ipiv = at::empty({n}, self.type().toScalarType(kInt));
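// (ipiv receives the LU pivot indices from getrf, which getri consumes
// below; the buffer is reused across the batch.)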
int lwork;
scalar_t wkopt;
Tensor work;

for (int64_t i = 0; i < batch_size; i++) {
int info;
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackGetrf<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
infos[i] = info;
if (info != 0) {
return;
}

// Run twice, first to get the optimum work size
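// (LAPACK convention: lwork == -1 requests a workspace-size query; the
// optimal size is written to wkopt instead of performing the inversion.)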
lwork = -1;
lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), &wkopt, lwork, &info);

lwork = static_cast<int>(wkopt);
work = at::empty({lwork}, self.type());

// now to compute the actual inverse
lapackGetri<scalar_t>(n, self_working_ptr, n, ipiv.data<int>(), work.data<scalar_t>(), lwork, &info);
infos[i] = info;
if (info != 0) {
return;
}
}
}

GENERATE_LINALG_HELPER_1_ARGS(inverse, self, cpu)

Tensor inverse(const Tensor &self) {
if (self.size(-1) == 0) {
return at::empty_like(self);
}
if (self.dim() == 2) {
return at::_getri_single(self);
}
inverseCheckInputs(self);
return at::_inverse_helper(self);
}

Tensor& inverse_out(Tensor &result, const Tensor &self) {
if (self.size(-1) == 0) {
return result.resize_as_(self);
}
result.copy_(native::inverse(self));
return result;
}

}} // namespace at::native
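For context, not part of the diff: a minimal sketch of how the two new batched entry points defined above might be exercised from ATen's C++ API, assuming a build with LAPACK available (random square matrices are almost surely invertible, so randn is safe for illustration):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  // Batched inverse: a stack of four 3x3 matrices inverted in one call,
  // instead of a loop over single-matrix inverses.
  at::Tensor a = at::randn({4, 3, 3});
  at::Tensor a_inv = at::inverse(a);  // shape: [4, 3, 3]

  // Batched gesv with broadcasting: batch dims [1, 5] (b) and [2, 1] (A)
  // broadcast to a common [2, 5] before the batched solve.
  at::Tensor A = at::randn({2, 1, 3, 3});
  at::Tensor b = at::randn({1, 5, 3, 2});
  at::Tensor solution, lu;
  std::tie(solution, lu) = at::gesv(b, A);  // solution: [2, 5, 3, 2]
  return 0;
}

The 2-D case still routes to the single-matrix kernels (at::_gesv_single, at::_getri_single), so existing callers are unaffected.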
125 changes: 0 additions & 125 deletions aten/src/ATen/native/Gesv.cpp

This file was deleted.

29 changes: 0 additions & 29 deletions aten/src/ATen/native/Gesv.h

This file was deleted.
