Closed
Changes from all commits (38 commits)
1090e20
Moved Gesv to ATen
animesht Jul 24, 2018
2144c29
Removed TH code, added CUDA support
animesht Jul 24, 2018
f64b620
Fixed CUDA support, removed gesv from THC, fixed autograd+tests
animesht Jul 25, 2018
e9cd831
Used options().dtype(kInt)
animesht Jul 26, 2018
333b230
addressed comments - simplified gesv_single
animesht Jul 26, 2018
152f0cb
Fixed autograd
animesht Jul 26, 2018
e20d906
Addressed comments
animesht Jul 27, 2018
83ac5f0
Fixed copying for output Tensors
animesht Jul 27, 2018
b0f5845
Use contiguous Tensor while copying B
animesht Jul 27, 2018
578193e
Fixed view/contiguous stuff again...
animesht Jul 27, 2018
3eb3450
Make A contiguous before copy
animesht Jul 28, 2018
62ffda1
Handle Tensor reuse
animesht Jul 29, 2018
5560826
Finally fixed the last issue with torch.gesv(b, A, out=(b,A))
animesht Jul 30, 2018
ff342bb
Addressed comments
animesht Jul 31, 2018
a4ae861
Use data_ptr() instead of address for comparison
animesht Jul 31, 2018
8c88229
Revert to address check, add missing case
animesht Aug 2, 2018
a43cc74
Use at::optional, out-of-place transpose
animesht Aug 3, 2018
b68c6f1
Fixed re-entrant safeness
animesht Aug 4, 2018
fdd2bd2
Revert error check, simplify temps, try to fix MSVC/windows build
animesht Aug 6, 2018
e441662
Merge branch 'master' of https://github.com/animesht/pytorch
animesht Aug 6, 2018
2a3644b
Fixes for MSVC and re-entrant safeness
animesht Aug 7, 2018
91de09f
Simplified logic, added notes
animesht Aug 8, 2018
bfb323c
Addressed comments
animesht Aug 9, 2018
854647e
Addressed comments
animesht Aug 9, 2018
36dd265
Fix error message
animesht Aug 10, 2018
419cdc8
renamed vars, pass infos by reference
animesht Aug 13, 2018
72c6afc
Fix formatting issues with %lld format specifier
animesht Aug 13, 2018
5ea2e77
Revert pass by reference
animesht Aug 13, 2018
0838375
Refactored
animesht Aug 17, 2018
f0e0739
Fixed typo
animesht Aug 17, 2018
c1bd0a0
Fix indent
animesht Aug 17, 2018
a01c96e
fix jenkins
animesht Aug 18, 2018
dac0811
fix jenkins
animesht Aug 18, 2018
c980e34
Addressed comments
animesht Aug 23, 2018
1bb8d22
Addressed comments
animesht Aug 23, 2018
5341fee
fixed typo
animesht Aug 23, 2018
5973c4e
Addressed comments
animesht Aug 28, 2018
6c573ec
Added tests
animesht Sep 6, 2018
21 changes: 0 additions & 21 deletions aten/src/ATen/Declarations.cwrap
@@ -2938,27 +2938,6 @@
- THTensor* tensor1
- THTensor* tensor2
]]
[[
name: _gesv_single
cname: gesv
types:
- Float
- Double
backends:
- CPU
- CUDA
variants:
- method
- function
return: argument 0,1
arguments:
- arg: THTensor* solution
output: True
- arg: THTensor* lu
output: True
- THTensor* self
- THTensor* A
]]
[[
name: gels
types:
60 changes: 48 additions & 12 deletions aten/src/ATen/native/Gesv.cpp
@@ -46,7 +46,7 @@ template<> void lapackGesv<double>(
template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
#ifndef USE_LAPACK
AT_ERROR("gesv: LAPACK library not found in compilation");
AT_ERROR("gesv : Lapack library not found in compile time");
#endif
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
@@ -57,7 +57,7 @@ static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
auto n = A.size(-2);
auto nrhs = b.size(-1);

auto ipiv = at::empty({n}, b.type().toScalarType(kInt));
auto ipiv = at::empty({n}, b.options().dtype(kInt));

for (int64_t i = 0; i < batch_size; i++) {
int info;
@@ -72,6 +72,37 @@ static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t> infos) {
}
}

std::tuple<Tensor&,Tensor&> _gesv_single_out_cpu(
Tensor& sol, Tensor& lu,
const Tensor& self, const Tensor& A) {
#ifndef USE_LAPACK
AT_ERROR("gesv : Lapack library not found in compile time");
#endif
int info = 0;
Tensor temp_sol;
Tensor temp_lu;
auto& A_tensor = prepareTensorsForLapack(A, lu, temp_lu);
auto& b_tensor = prepareTensorsForLapack(self, sol, temp_sol);

AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
const int64_t n = sol.size(0);
const int64_t nrhs = sol.size(1);
auto A_ptr = A_tensor.data<scalar_t>();
auto b_ptr = b_tensor.data<scalar_t>();
auto ipiv = at::empty({n}, sol.options().dtype(kInt));
lapackGesv<scalar_t>(n, nrhs, A_ptr, n, ipiv.data<int>(), b_ptr, n, &info);
});
checkErrors({info});

if (temp_sol.defined()) {
sol.copy_(temp_sol);
}
if (temp_lu.defined()) {
lu.copy_(temp_lu);
}
return std::tuple<Tensor&, Tensor&>(sol, lu);
}

std::tuple<Tensor,Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
std::vector<int64_t> infos(batchCount(A), 0);
auto A_working_copy = cloneBatchedColumnMajor(A);
@@ -83,17 +114,21 @@ std::tuple<Tensor,Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A)
return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

std::tuple<Tensor,Tensor> _gesv_single(const Tensor& self, const Tensor& A) {
auto sol = self.type().tensor();
auto lu = self.type().tensor();
return self.type()._gesv_single_out(sol, lu, self, A);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
if (self.dim() <= 2 && A.dim() <= 2) {
// TODO: #7102: It's not necessary to have gesv (single) bindings for both
// TH and ATen. We should remove the TH gesv bindings, especially
// since the lapackGesv function is already in ATen.
bool batched = !(self.dim() <= 2 && A.dim() <= 2);
checkInputs(self, A, batched);

if (!batched) {
return at::_gesv_single(self, A);
}

checkInputs(self, A);

// broadcast the batch dimensions of self and A.
IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
@@ -114,13 +149,14 @@ std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
}

std::tuple<Tensor&,Tensor&> gesv_out(
Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
Tensor& sol, Tensor& lu, const Tensor& self, const Tensor& A) {
if (self.dim() > 2 || A.dim() > 2) {
AT_ERROR("torch.gesv() with the `out` keyword does not support batching. "
"b.dim() (%lld) and A.dim() (%lld) must both be 2.",
(long long)self.dim(), (long long)A.dim());
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(),
") must both be 2.");
}
return at::_gesv_single_out(solution, lu, self, A);

return self.type()._gesv_single_out(sol, lu, self, A);
}

}} // namespace at::native
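
For reference (not part of the diff): a minimal usage sketch of the dispatch implemented above. Shapes and tolerance are illustrative assumptions; 2-D inputs go through at::_gesv_single, while inputs with batch dimensions take the broadcasting batched path. torch.gesv was the user-facing binding at the time of this PR (later superseded by torch.solve).

    import torch

    A = torch.randn(3, 3)
    b = torch.randn(3, 1)

    # 2-D inputs: single-matrix path (at::_gesv_single)
    x, lu = torch.gesv(b, A)
    print(torch.allclose(A.mm(x), b, atol=1e-5))  # expected True for well-conditioned A

    # inputs with batch dimensions: batched path; batch dims of b and A are broadcast
    A_batch = torch.randn(4, 3, 3)
    b_batch = torch.randn(4, 3, 1)
    x_batch, lu_batch = torch.gesv(b_batch, A_batch)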
94 changes: 80 additions & 14 deletions aten/src/ATen/native/Gesv.h
@@ -1,30 +1,96 @@
#include <utility>
#include "ATen/ATen.h"

namespace at { namespace native {

static inline void checkInputs(const Tensor& self, const Tensor& A) {
if (A.size(-1) != A.size(-2)) {
AT_ERROR("A must be batches of square matrices, "
"but they are %lld by %lld matrices",
(long long)A.size(-1), (long long)A.size(-2));
static inline bool isTransposeContiguous(Tensor& self) {
return self.dim() == 2 &&
self.stride(0) == 1 &&
self.stride(1) == self.size(0);
}

/* gesv takes (self, A) and returns (sol, lu).
* (i) output tensors (sol, lu) may be same as input tensors (self, A)
* (ii) for 2D matrices, .t_() represents their column-major format
*
* Before passing pointers to Lapack, we need to ensure that these pointers
* represent Fortran-contiguous tensors in column-major format
*
* Cases:
* 1) `out` has correct shape but elements do not form a contiguous
* chunk of memory. Since shape is correct, we don't resize_ it. Instead, we
* clone the input tensor into a buffer, use the buffer for Lapack and finally
* copy the buffer to the output tensor.
*
* 2) out.t() is contiguous:
* (i) &in == &out: use out.data() as is. Do nothing
* (ii) &in != &out: copy in.t() to out.t()
* 3) out.t() is not contiguous:
* - resize_ should fix contiguity/size issues
* (i) &in == &out: copy in.t().clone() to out (same tensor)
* (ii) &in != &out: copy in.t() to out
*/
static inline Tensor& prepareTensorsForLapack(
const Tensor& in, Tensor& out, Tensor& temp) {
int64_t x = in.size(0);
int64_t y = (in.dim() == 1) ? 1 : in.size(1);
bool out_tc = isTransposeContiguous(out);
bool out_correct_shape =
out.dim() == 2 && out.size(0) == x && out.size(1) == y;

// view potential 1D `in` as 2D
auto in_t = in.view({x, y}).t_();

if (!out_tc && !out.is_contiguous() && out_correct_shape) {
temp = in_t.clone().t_();
} else if (out_tc && &in != &out) {
out.t().resize_({y, x}).copy_(in_t);
} else if (!out_tc) {
out.resize_({y, x});
if (&in == &out) {
out.copy_(in_t.clone()).t_();
} else {
out.copy_(in_t).t_();
}
}
if (A.size(-1) != self.size(-2)) {
AT_ERROR("Incompatible matrix sizes for matmul: each A "
"matrix is %llu by %lld but each b matrix is %lld by %lld.",
(long long)A.size(-1), (long long)A.size(-1),
(long long)self.size(-2), (long long)self.size(-1));
// return ref to usable tensor for Lapack
return temp.defined() ? temp : out;
}

static inline void checkInputs(const Tensor& self, const Tensor& A, bool batched) {
if (batched) {
if (A.size(-1) != A.size(-2)) {
AT_ERROR("A must be batches of square matrices, "
"but they are ", A.size(-1), " by ", A.size(-2), " matrices");
} else if (A.size(-1) != self.size(-2)) {
AT_ERROR("incompatible matrix sizes for matmul: each a "
"matrix is ", A.size(-1), " by ", A.size(-1),
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
} else {
if (A.dim() != 2) {
AT_ERROR("A should have 2 dimensions, but has ", A.dim());
} else if (self.dim() != 1 && self.dim() != 2) {
AT_ERROR("B should have 1 or 2 dimensions, but has ", self.dim());
} else if (A.size(0) != A.size(1)) {
AT_ERROR("A must be a square matrix, but is ",
A.size(0), " by ", A.size(1));
} else if (A.size(0) != self.size(0)) {
AT_ERROR("A,B size incompatible - A has ", A.size(0),
" rows, B has ", self.size(0), " cols");
}
}
}

static inline void checkErrors(std::vector<int64_t> infos) {
for (size_t i = 0; i < infos.size(); i++) {
auto info = infos[i];
if (info < 0) {
AT_ERROR("gesv: For batch %lld: Argument %lld has illegal value",
(long long)i, -info);
AT_ERROR("gesv: For batch ", i, ": Argument ",
-info, " has illegal value");
} else if (info > 0) {
AT_ERROR("gesv: For batch %lld: U(%lld,%lld) is zero, singular U.",
(long long)i, info, info);
AT_ERROR("gesv: For batch ", i, ": U(", info, ",", info,
") is zero, singular U.");
}
}
}
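
For reference (not part of the diff): a rough sketch of the out-tensor situations that prepareTensorsForLapack above handles, including the input-reuse pattern torch.gesv(b, A, out=(b, A)) mentioned in the commit history. Shapes are illustrative assumptions.

    import torch

    A = torch.randn(3, 3)
    b = torch.randn(3, 2)

    # preallocated outputs: the solution and the LU factors are written into them
    sol = torch.empty(3, 2)
    lu = torch.empty(3, 3)
    torch.gesv(b, A, out=(sol, lu))

    # reusing the inputs as outputs (the &in == &out cases documented above);
    # afterwards b holds the solution and A holds the LU factorization
    torch.gesv(b, A, out=(b, A))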
54 changes: 54 additions & 0 deletions aten/src/ATen/native/cuda/Gesv.cu
@@ -19,6 +19,28 @@ namespace at {
namespace native {

#ifdef USE_MAGMA

template<class scalar_t>
void magmaGesv(
int64_t n, int64_t nrhs, scalar_t* A_data, int64_t lda,
int* ipiv, scalar_t* B_data, int64_t ldb, int* info) {
AT_ERROR("magma: gesv only takes float or double Tensors");
}

template<>
void magmaGesv<float>(
int64_t n, int64_t nrhs, float* A_data, int64_t lda,
int* ipiv, float* B_data, int64_t ldb, int* info) {
magma_sgesv_gpu(n, nrhs, A_data, lda, ipiv, B_data, ldb, info);
}

template<>
void magmaGesv<double>(
int64_t n, int64_t nrhs, double* A_data, int64_t lda,
int* ipiv, double* B_data, int64_t ldb, int* info) {
magma_dgesv_gpu(n, nrhs, A_data, lda, ipiv, B_data, ldb, info);
}

template<class scalar_t>
void magmaGesvBatched(
magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
@@ -138,6 +160,38 @@ std::tuple<Tensor,Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A)
return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

std::tuple<Tensor&,Tensor&> _gesv_single_out_cuda(Tensor& sol, Tensor& lu,
const Tensor& self, const Tensor& A) {
#ifndef USE_MAGMA
AT_ERROR("gesv: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
int info = 0;
int* ipiv;
Tensor temp_sol;
Tensor temp_lu;
auto& A_tensor = prepareTensorsForLapack(A, lu, temp_lu);
auto& b_tensor = prepareTensorsForLapack(self, sol, temp_sol);

AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
const int64_t n = sol.size(0);
const int64_t nrhs = sol.size(1);
auto A_ptr = A_tensor.data<scalar_t>();
auto b_ptr = b_tensor.data<scalar_t>();
ALLOCATE_ARRAY(ipiv, int, n, sol);
magmaGesv<scalar_t>(n, nrhs, A_ptr, n, ipiv, b_ptr, n, &info);
});
checkErrors({info});

if (temp_sol.defined()) {
sol.copy_(temp_sol);
}
if (temp_lu.defined()) {
lu.copy_(temp_lu);
}
return std::tuple<Tensor&,Tensor&>(sol, lu);
#endif
}
}} // namespace at::native

#undef ALLOCATE_ARRAY
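
For reference (not part of the diff): the CUDA path mirrors the CPU one but calls MAGMA. A small sketch, assuming a CUDA device and a MAGMA-enabled build (otherwise the AT_ERROR above is raised):

    import torch

    if torch.cuda.is_available():
        A = torch.randn(3, 3, device='cuda')
        b = torch.randn(3, 1, device='cuda')
        # 2-D CUDA inputs dispatch to _gesv_single_out_cuda (MAGMA)
        x, lu = torch.gesv(b, A)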
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -720,6 +720,14 @@
CPU: _gesv_helper_cpu
CUDA: _gesv_helper_cuda

- func: _gesv_single(Tensor self, Tensor A) -> (Tensor, Tensor)

- func: _gesv_single_out(Tensor solution, Tensor lu, Tensor self, Tensor A) -> (Tensor, Tensor)
variants: function
dispatch:
CPU: _gesv_single_out_cpu
CUDA: _gesv_single_out_cuda

- func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enabled=True) -> Tensor
variants: function

17 changes: 0 additions & 17 deletions aten/src/TH/generic/THLapack.cpp
@@ -3,8 +3,6 @@
#else


TH_EXTERNC void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
TH_EXTERNC void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
TH_EXTERNC void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
TH_EXTERNC void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
@@ -37,21 +35,6 @@ TH_EXTERNC void spstrf_(char *uplo, int *n, float *a, int *lda, int *piv, int *r
TH_EXTERNC void dpstrf_(char *uplo, int *n, double *a, int *lda, int *piv, int *rank, double *tol, double *work, int *info);


/* Compute the solution to a real system of linear equations A * X = B */
void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info)
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#else
sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#endif
#else
THError("gesv : Lapack library not found in compile time\n");
#endif
return;
}

/* Solve a triangular system of the form A * X = B or A^T * X = B */
void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info)
{
2 changes: 0 additions & 2 deletions aten/src/TH/generic/THLapack.h
@@ -2,8 +2,6 @@
#define TH_GENERIC_FILE "generic/THLapack.h"
#else

/* AX=B */
TH_API void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info);
/* Solve a triangular system of the form A * X = B or A^T * X = B */
TH_API void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info);
/* ||AX-B|| */