2 changes: 1 addition & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -2671,7 +2671,7 @@
default: U
]]
[[
name: _th_potrs
name: _th_potrs_single
cname: potrs
types:
- Float
3 changes: 1 addition & 2 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -81,8 +81,6 @@ _(aten, _floor) \
_(aten, _fused_dropout) \
_(aten, _ger) \
_(aten, _gesv_helper) \
_(aten, _gesv_single) \
_(aten, _getri_single) \
_(aten, _indexCopy) \
_(aten, _indices) \
_(aten, _inverse_helper) \
@@ -103,6 +101,7 @@ _(aten, _pack_padded_sequence_backward) \
_(aten, _pad_packed_sequence) \
_(aten, _pdist_backward) \
_(aten, _pdist_forward) \
_(aten, _potrs_helper) \
_(aten, _prod) \
_(aten, _prodall) \
_(aten, _range) \
128 changes: 103 additions & 25 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1,7 +1,6 @@
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/native/LinearAlgebraUtils.h"
@@ -16,14 +15,18 @@
#ifdef USE_LAPACK

// gesv
extern "C" void dgesv_(int* n, int* nrhs, double* a, int* lda, int *ipiv, double* b, int* ldb, int* info);
extern "C" void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b, int* ldb, int* info);
extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);

// inverse
extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);

// potrs
extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
#endif

namespace at {
@@ -32,12 +35,12 @@ namespace native {
// Define the per-batch functions to be used in the main implementation of the batched
// linear algebra operations
template<class scalar_t>
void lapackGesv(int n, int nrhs, scalar_t* a, int lda, int* ipiv, scalar_t* b, int ldb, int* info) {
void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
AT_ERROR("gesv only takes float or double Tensors");
}

template<class scalar_t>
void lapackGetrf(int m, int n, scalar_t* a, int lda, int *ipiv, int *info) {
void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
AT_ERROR("getrf only takes float or double Tensors");
}

@@ -46,12 +49,17 @@ void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info) {
AT_ERROR("getri only takes float or double Tensors");
}

template<class scalar_t>
void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
AT_ERROR("potrs only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackGesv<double>(int n, int nrhs, double* a, int lda, int* ipiv, double* b, int ldb, int* info) {
template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackGesv<float>(int n, int nrhs, float* a, int lda, int* ipiv, float* b, int ldb, int* info) {
template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

@@ -70,6 +78,14 @@ template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
sgetrf_(&m, &n, a, &lda, ipiv, info);
}

template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}

template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}
#endif
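
For reference, a minimal standalone sketch of the column-major calling convention these wrappers assume (an illustration assuming a LAPACK build, not part of this diff): dpotrs_ solves A x = b given the Cholesky factor produced by dpotrf_.

extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);

int main() {
  char uplo = 'L';
  int n = 2, nrhs = 1, info = 0;
  double a[4] = {4, 2, 2, 3};  // column-major SPD matrix [[4, 2], [2, 3]]
  double b[2] = {6, 5};        // right-hand side
  dpotrf_(&uplo, &n, a, &n, &info);                // in-place factorization A = L * L^T
  dpotrs_(&uplo, &n, &nrhs, a, &n, b, &n, &info);  // b now holds the solution x = {1, 1}
  return info;
}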

// Below are the definitions of the functions operating on a batch that are going to be dispatched
@@ -105,8 +121,16 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
}
}

// These utilities are specified in LinearAlgebraUtils.h
GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cpu)
std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
std::vector<int64_t> infos(batchCount(self), 0);
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
});
batchCheckErrors(infos, "gesv");
return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
@@ -117,21 +141,8 @@ std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
return at::_th_gesv_single(self, A);
}

gesvCheckInputs(self, A);

// broadcast the batch dimensions of self and A.
IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
std::vector<int64_t> expand_batch_portion = infer_size(self_batch_sizes, A_batch_sizes);

std::vector<int64_t> self_expand_size({expand_batch_portion});
self_expand_size.insert(self_expand_size.end(), { self.size(-2), self.size(-1) });

std::vector<int64_t> A_expand_size({expand_batch_portion});
A_expand_size.insert(A_expand_size.end(), { A.size(-2), A.size(-1) });

Tensor self_broadcasted = self.expand(self_expand_size);
Tensor A_broadcasted = A.expand(A_expand_size);
Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
return at::_gesv_helper(self_broadcasted, A_broadcasted);
}
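
An illustrative usage sketch of the broadcasting path above (the shapes are assumptions, not from this diff): batch dims {2, 1} and {3} broadcast to {2, 3} before dispatch to the helper.

auto A = at::rand({2, 1, 4, 4});   // batch dims {2, 1}, each matrix 4x4
auto b = at::rand({3, 4, 6});      // batch dims {3}, each right-hand side 4x6
at::Tensor x, lu;
std::tie(x, lu) = at::gesv(b, A);  // broadcasts, then calls at::_gesv_helper
// x has shape {2, 3, 4, 6}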

@@ -185,7 +196,15 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
}
}

GENERATE_LINALG_HELPER_1_ARGS(inverse, self, cpu)
Tensor _inverse_helper_cpu(const Tensor& self) {
std::vector<int64_t> infos(batchCount(self), 0);
auto self_working_copy = cloneBatchedColumnMajor(self);
AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
apply_inverse<scalar_t>(self_working_copy, infos);
});
batchCheckErrors(infos, "inverse");
return self_working_copy;
}

Tensor inverse(const Tensor &self) {
if (self.size(-1) == 0) {
@@ -206,4 +225,63 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) {
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("potrs: LAPACK library not found in compilation");
#endif
char uplo = upper ? 'U' : 'L';

auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);

auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);

for (int64_t i = 0; i < batch_size; i++) {
int info;
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
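// cloneBatchedColumnMajor (used by _potrs_helper_cpu below) hands this loop
// contiguous column-major matrices, so lda == ldb == n for every batch slice.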
lapackPotrs<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
infos[i] = info;
if (info != 0) {
return;
}
}
}

Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
std::vector<int64_t> infos(batchCount(self), 0);
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, infos);
});
batchCheckErrors(infos, "potrs");
return self_working_copy;
}

// Supports arbitrary batch dimensions for self and A
Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
if (self.dim() <= 2 && A.dim() <= 2) {
return at::_th_potrs_single(self, A, upper);
}

Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
return at::_potrs_helper(self_broadcasted, A_broadcasted, upper);
}
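
A usage sketch (the SPD construction and shapes are assumptions, not part of this diff): a single 2-dim Cholesky factor broadcast against a batch of right-hand sides.

auto A = at::rand({4, 4});
A = A.matmul(A.t()) + at::eye(4) * 4;      // make A symmetric positive definite
auto u = at::cholesky(A, /*upper=*/true);  // 2-dim factor (maps to _th_potrf)
auto b = at::rand({3, 4, 2});              // batch of 3 right-hand sides
auto x = at::potrs(b, u, /*upper=*/true);  // u broadcasts to batch dims {3}
// x has shape {3, 4, 2}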

Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.potrs() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
return at::_th_potrs_single_out(result, self, A, upper);
}

}} // namespace at::native
8 changes: 0 additions & 8 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -491,14 +491,6 @@ Tensor cholesky(const Tensor & self, bool upper) {
return at::_th_potrf(self, upper);
}

Tensor & potrs_out(Tensor & result, const Tensor & self, const Tensor & input2, bool upper) {
return at::_th_potrs_out(result, self, input2, upper);
}

Tensor potrs(const Tensor & self, const Tensor & input2, bool upper) {
return at::_th_potrs(self, input2, upper);
}

Tensor & potri_out(Tensor & result, const Tensor & self, bool upper) {
return at::_th_potri_out(result, self, upper);
}
47 changes: 22 additions & 25 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -1,4 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
#include <limits>

namespace at { namespace native {
@@ -52,8 +53,8 @@ static inline double _get_epsilon(const ScalarType& sc_type) {
}
}

// Validates input shapes for gesv
static inline void gesvCheckInputs(const Tensor& self, const Tensor& A) {
// Validates input shapes for linear solve methods (gesv, potrs)
static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
AT_CHECK(A.size(-1) == A.size(-2),
"A must be batches of square matrices, "
"but they are ", A.size(-1), " by ", A.size(-2), " matrices");
@@ -87,34 +88,30 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* name) {
}
}

#define GENERATE_LINALG_HELPER_1_ARGS(NAME, ARG, BACKEND) \
Tensor _##NAME##_helper_##BACKEND(const Tensor& ARG) { \
std::vector<int64_t> infos(batchCount(ARG), 0); \
auto ARG##_working_copy = cloneBatchedColumnMajor(ARG); \
AT_DISPATCH_FLOATING_TYPES(ARG.type(), #NAME, [&]{ \
apply_##NAME<scalar_t>(ARG##_working_copy, infos); \
}); \
batchCheckErrors(infos, #NAME); \
return ARG##_working_copy; \
}

#define GENERATE_LINALG_HELPER_2_ARGS(NAME, ARG1, ARG2, BACKEND) \
std::tuple<Tensor, Tensor> _##NAME##_helper_##BACKEND(const Tensor& ARG1, const Tensor& ARG2) { \
std::vector<int64_t> infos(batchCount(ARG1), 0); \
auto ARG1##_working_copy = cloneBatchedColumnMajor(ARG1); \
auto ARG2##_working_copy = cloneBatchedColumnMajor(ARG2); \
AT_DISPATCH_FLOATING_TYPES(ARG1.type(), #NAME, [&]{ \
apply_##NAME<scalar_t>(ARG1##_working_copy, ARG2##_working_copy, infos); \
}); \
batchCheckErrors(infos, #NAME); \
return std::tuple<Tensor, Tensor>(ARG1##_working_copy, ARG2##_working_copy); \
}

// Checks if all the Tensors in a TensorList are of the same dimensions
static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
for (auto &t : tensors) {
AT_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
}
}

static inline std::tuple<Tensor,Tensor> _linear_solve_broadcast_args(const Tensor& arg1, const Tensor& arg2) {
linearSolveCheckInputs(arg1, arg2);

// broadcast the batch dimensions of arg1 and arg2.
IntList arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
IntList arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);

std::vector<int64_t> arg1_expand_size({expand_batch_portion});
arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });

std::vector<int64_t> arg2_expand_size({expand_batch_portion});
arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });

Tensor arg1_broadcasted = arg1.expand(arg1_expand_size);
Tensor arg2_broadcasted = arg2.expand(arg2_expand_size);
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
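
A worked shape example for this helper (the shapes are illustrative): the batch portions {5, 1} and {4} merge to {5, 4} via infer_size, and each argument's trailing matrix dims are appended back.

auto b = at::rand({5, 1, 3, 2});
auto A = at::rand({4, 3, 3});
at::Tensor b_bc, A_bc;
std::tie(b_bc, A_bc) = at::native::_linear_solve_broadcast_args(b, A);
// b_bc has shape {5, 4, 3, 2}; A_bc has shape {5, 4, 3, 3}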

}} // namespace at::native