3 changes: 2 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -3004,7 +3004,8 @@
     - THTensor* tensor2
 ]]
 [[
-  name: gesv
+  name: _gesv_single
+  cname: gesv
   types:
     - Float
     - Double
126 changes: 126 additions & 0 deletions aten/src/ATen/native/Gesv.cpp
@@ -0,0 +1,126 @@
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"

#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/native/Gesv.h"

#include "TH.h" // for USE_LAPACK

#include <vector>

#ifdef USE_LAPACK
extern "C" void dgesv_(
    int* n, int* nrhs, double* a, int* lda,
    int* ipiv, double* b, int* ldb, int* info);
extern "C" void sgesv_(
    int* n, int* nrhs, float* a, int* lda,
    int* ipiv, float* b, int* ldb, int* info);
#endif

namespace at { namespace native {

template<class scalar_t>
void lapackGesv(
    int n, int nrhs, scalar_t* a, int lda, int* ipiv,
    scalar_t* b, int ldb, int* info) {
  AT_ERROR("gesv only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackGesv<float>(
    int n, int nrhs, float* a, int lda, int* ipiv,
    float* b, int ldb, int* info) {
  sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackGesv<double>(
    int n, int nrhs, double* a, int lda, int* ipiv,
    double* b, int ldb, int* info) {
  dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}
#endif

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
  AT_ERROR("gesv: LAPACK library not found in compilation");
#endif
  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  auto batch_size = batchCount(A);
  auto n = A.size(-2);
  auto nrhs = b.size(-1);

  auto ipiv = b.type().toScalarType(kInt).tensor(n);

  for (int64_t i = 0; i < batch_size; i++) {
    int info;
    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
    lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(),
        b_working_ptr, n, &info);
    infos[i] = info;
    if (info != 0) {
      return;
    }
  }
}

std::tuple<Tensor,Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
  std::vector<int64_t> infos(batchCount(A), 0);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto b_working_copy = cloneBatchedColumnMajor(self);
  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
  });
  checkErrors(infos);
  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
  if (self.dim() <= 2 && A.dim() <= 2) {
    // TODO: #7102: It's not necessary to have gesv (single) bindings for both
    // TH and ATen. We should remove the TH gesv bindings, especially
    // since the lapackGesv function is already in ATen.
    return at::_gesv_single(self, A);
  }

  checkInputs(self, A);

  // broadcast the batch dimensions of self and A.
  IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
  IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
  std::vector<int64_t> expand_batch_portion =
      infer_size(self_batch_sizes, A_batch_sizes);

  std::vector<int64_t> self_expand_size({expand_batch_portion});
  self_expand_size.insert(self_expand_size.end(),
      { self.size(-2), self.size(-1) });

  std::vector<int64_t> A_expand_size({expand_batch_portion});
  A_expand_size.insert(A_expand_size.end(),
      { A.size(-2), A.size(-1) });

  Tensor self_broadcasted = self.expand(self_expand_size);
  Tensor A_broadcasted = A.expand(A_expand_size);
  return self.type()._gesv_helper(self_broadcasted, A_broadcasted);
}

std::tuple<Tensor&,Tensor&> gesv_out(
    Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
  if (self.dim() > 2 || A.dim() > 2) {
    AT_ERROR("torch.gesv() with the `out` keyword does not support batching. "
        "b.dim() (%lld) and A.dim() (%lld) must both be 2.",
        (long long)self.dim(), (long long)A.dim());
  }
  return at::_gesv_single_out(solution, lu, self, A);
}

}} // namespace at::native
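For orientation, here is a rough Python-level sketch of the batched behavior this file implements (shapes are arbitrary; assumes a build that includes this change):

import torch

# Batched solve: each (5, 5) matrix in A gets its own right-hand side from b.
A = torch.randn(2, 3, 5, 5)
b = torch.randn(2, 3, 5, 4)
x, lu = torch.gesv(b, A)
print(x.shape)                               # torch.Size([2, 3, 5, 4])
print((torch.matmul(A, x) - b).abs().max())  # ~0, up to floating-point error

# Batch dimensions broadcast, mirroring the infer_size/expand logic in gesv() above.
A1 = torch.randn(1, 5, 5)
b1 = torch.randn(2, 2, 5, 4)
x1, _ = torch.gesv(b1, A1)
print(x1.shape)                              # torch.Size([2, 2, 5, 4])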
32 changes: 32 additions & 0 deletions aten/src/ATen/native/Gesv.h
@@ -0,0 +1,32 @@
#include "ATen/ATen.h"

namespace at { namespace native {

static inline void checkInputs(const Tensor& self, const Tensor& A) {
  if (A.size(-1) != A.size(-2)) {
    AT_ERROR("A must be batches of square matrices, "
        "but they are %lld by %lld matrices",
        (long long)A.size(-1), (long long)A.size(-2));
  }
  if (A.size(-1) != self.size(-2)) {
    AT_ERROR("Incompatible matrix sizes for matmul: each A "
        "matrix is %lld by %lld but each b matrix is %lld by %lld.",
        (long long)A.size(-1), (long long)A.size(-1),
        (long long)self.size(-2), (long long)self.size(-1));
  }
}

static inline void checkErrors(std::vector<int64_t> infos) {
  for (size_t i = 0; i < infos.size(); i++) {
    auto info = infos[i];
    if (info < 0) {
      AT_ERROR("gesv: For batch %lld: Argument %lld has illegal value",
          (long long)i, (long long)-info);
    } else if (info > 0) {
      AT_ERROR("gesv: For batch %lld: U(%lld,%lld) is zero, singular U.",
          (long long)i, (long long)info, (long long)info);
    }
  }
}

}} // namespace at::native
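The two branches in checkErrors follow LAPACK's info convention (negative: an argument was illegal, positive: a zero pivot was hit). A small, hedged illustration of how this surfaces in Python; the exact message wording may differ:

import torch

A = torch.zeros(2, 3, 3)   # singular matrices, so the factorization hits a zero pivot
b = torch.randn(2, 3, 1)
try:
    torch.gesv(b, A)
except RuntimeError as e:
    print(e)               # expected to mention the failing batch and the singular U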
42 changes: 42 additions & 0 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -0,0 +1,42 @@
#include "ATen/ATen.h"

namespace at { namespace native {

/*
 * Clones a Tensor so that the following conditions hold:
 * If we think of a Tensor as having size (B, M, N), where B is any number
 * of batch dimensions, then:
 * - Each (M, N) matrix is in column major form
 * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
 *   Then, when laid out in memory, the M by N matrix starting at
 *   P.data_ptr()[b * M * N] corresponds to the same batch entry as the
 *   M' by N' matrix starting at Q.data_ptr()[b * M' * N'].
 */
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  // If src is already in batched column major format, then
  // this will be efficient (no reordering of the data will occur)
  // because the first transpose will make the tensor contiguous,
  // and cloning a contiguous tensor is fast.
  auto result = src.transpose(-2, -1).clone();
  result.transpose_(-2, -1);
  return result;
}

/*
 * Given batches of matrices with arbitrary batch dim,
 * computes the number of batches.
 */
static inline int64_t batchCount(const Tensor& batched_matrices) {
  int64_t result = 1;
  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
    result *= batched_matrices.size(i);
  }
  return result;
}

// Computes the number of elements of a matrix in a batched matrix tensor
static inline int64_t matrixStride(const Tensor& batched_matrices) {
  return batched_matrices.size(-1) * batched_matrices.size(-2);
}

}} // namespace at::native
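A small Python sketch of what these helpers compute, using an arbitrary (B, M, N) = (3, 4, 5) tensor:

import torch

src = torch.randn(3, 4, 5)
col_major = src.transpose(-2, -1).clone().transpose(-2, -1)  # same trick as cloneBatchedColumnMajor
print(col_major.size())    # torch.Size([3, 4, 5]) -- sizes unchanged
print(col_major.stride())  # (20, 1, 4) -- each 4x5 matrix is laid out column major

batch_count = 1
for s in src.size()[:-2]:
    batch_count *= s       # batchCount: product of the batch dims = 3
matrix_stride = src.size(-1) * src.size(-2)   # matrixStride: 20 elements per matrix
print(batch_count, matrix_stride)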
142 changes: 142 additions & 0 deletions aten/src/ATen/native/cuda/Gesv.cu
@@ -0,0 +1,142 @@
#include "ATen/Context.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
#include "ATen/PinnedMemoryAllocator.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"

#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/native/Gesv.h"

#include "THC.h" // for USE_MAGMA

#ifdef USE_MAGMA
#include <magma.h>
#include <magma_types.h>
#endif

namespace at {
namespace native {

#ifdef USE_MAGMA
template<class scalar_t>
void magmaGesvBatched(
    magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  AT_ERROR("gesv only takes float or double Tensors");
}

template<>
void magmaGesvBatched<float>(
    magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_sgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}

template<>
void magmaGesvBatched<double>(
    magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_dgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}

static magma_queue_t createMagmaQueue(const Tensor& tensor) {
  auto& context = tensor.type().get_context();
  magma_queue_t magma_queue;
  magma_queue_create_from_cuda(
      tensor.get_device(),
      context.getCurrentCUDAStream(),
      THCState_getCurrentBlasHandle(context.thc_state),
      THCState_getCurrentSparseHandle(context.thc_state),
      &magma_queue);
  return magma_queue;
}
#endif

static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  if (static_cast<int64_t>(result) != value) {
    AT_ERROR("magma: The value of %s (%lld) is too large to fit into a magma_int_t (%llu bytes)",
        varname, (long long)value, (unsigned long long)sizeof(magma_int_t));
  }
  return result;
}

// Creates an array of size elements of type T, backed by pinned memory
// wrapped in a Storage
template<class T>
static inline std::unique_ptr<Storage> pin_memory(int64_t size, Tensor dummy) {
  int64_t adjusted_size = size * sizeof(T);
  auto allocator = std::unique_ptr<Allocator>(new PinnedMemoryAllocator());
  auto& backend = dummy.type().toBackend(kCPU).toScalarType(kByte);
  return backend.storageWithAllocator(adjusted_size, std::move(allocator));
}

#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
  auto storage_##name = pin_memory<type>(size, dummy_tensor); \
  name = reinterpret_cast<type*>(storage_##name->data());

template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_MAGMA
  AT_ERROR("gesv: MAGMA library not found in "
      "compilation. Please rebuild with MAGMA.");
#else
  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

  magma_int_t* info_array;
  magma_int_t* ipiv_data;
  magma_int_t** ipiv_array;
  scalar_t** A_array;
  scalar_t** b_array;

  ALLOCATE_ARRAY(info_array, magma_int_t, batch_size, b);
  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);

  // Set up the created arrays
  for (int64_t i = 0; i < batch_size; i++) {
    A_array[i] = &A_data[i * A_mat_stride];
    b_array[i] = &b_data[i * b_mat_stride];
    ipiv_array[i] = &ipiv_data[i * n];
  }

  magmaGesvBatched<scalar_t>(
      n, nrhs, A_array, n, ipiv_array, b_array, n,
      info_array, batch_size, createMagmaQueue(b));

  for (int64_t i = 0; i < batch_size; i++) {
    infos[i] = info_array[i];
  }
#endif
}

std::tuple<Tensor,Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
  std::vector<int64_t> infos(batchCount(A), 0);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto b_working_copy = cloneBatchedColumnMajor(self);
  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
  });
  checkErrors(infos);
  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}

}} // namespace at::native

#undef ALLOCATE_ARRAY
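Both the CPU and MAGMA paths rely on the same contiguous layout: matrix i starts matrixStride(A) elements after matrix i-1, which is exactly what the A_array/b_array pointer setup above encodes. A quick Python check of that arithmetic on a contiguous batch:

import torch

A = torch.randn(4, 3, 3)                  # contiguous batch of 3x3 float matrices
mat_stride = A.size(-1) * A.size(-2)      # 9 elements per matrix, as in matrixStride()
print(A[1].data_ptr() - A[0].data_ptr())  # 36 bytes between consecutive matrices
print(mat_stride * A.element_size())      # 36 bytes as well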
11 changes: 11 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -494,6 +494,17 @@
- func: ger_out(Tensor result, Tensor self, Tensor vec2) -> Tensor
  variants: function

- func: gesv(Tensor self, Tensor A) -> (Tensor, Tensor)

- func: gesv_out(Tensor solution, Tensor lu, Tensor self, Tensor A) -> (Tensor, Tensor)
  variants: function

# gesv handles broadcasting of arbitrary batch dims while _gesv_helper does not.
- func: _gesv_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
  dispatch:
    CPU: _gesv_helper_cpu
    CUDA: _gesv_helper_cuda
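Putting the entries together: gesv routes 2-D inputs to the single-matrix binding and batched inputs through broadcasting to _gesv_helper, while gesv_out only accepts the 2-D case. A rough sketch from the Python side (assuming the usual two-tensor out= tuple):

import torch

A2, b2 = torch.randn(5, 5), torch.randn(5, 3)
x2, lu2 = torch.gesv(b2, A2)                  # 2-D inputs: single-matrix path

sol, lu = torch.empty(5, 3), torch.empty(5, 5)
torch.gesv(b2, A2, out=(sol, lu))             # out= works for the 2-D case only

A4, b4 = torch.randn(2, 1, 5, 5), torch.randn(2, 1, 5, 3)
x4, lu4 = torch.gesv(b4, A4)                  # batched inputs: broadcast, then _gesv_helper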

- func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enabled=True) -> Tensor
  variants: function

4 changes: 4 additions & 0 deletions test/test_autograd.py
@@ -2869,6 +2869,10 @@ class dont_convert(tuple):
    ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
     'large', NO_ARGS, [skipIfNoLapack]),
    ('gesv', (S, S), ((S, S),), '', NO_ARGS, [skipIfNoLapack]),
    ('gesv', (S, S, S), ((S, S, S),), 'batched', NO_ARGS, [skipIfNoLapack]),
    ('gesv', (2, 3, S, S), ((2, 3, S, S),), 'batched_dims', NO_ARGS, [skipIfNoLapack]),
    ('gesv', (2, 2, S, S), ((1, S, S),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
    ('gesv', (1, S, S), ((2, 2, S, S),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
    ('fill_', (S, S, S), (1,), 'number'),
    ('fill_', (), (1,), 'number_scalar'),
    # FIXME: we should compute the derivative w.r.t torch.tensor(1)
8 changes: 8 additions & 0 deletions test/test_cuda.py
@@ -1320,6 +1320,14 @@ def _select_broadcastable_dims(dims_full=None):
    def test_det_logdet_slogdet(self):
        TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda())

    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
    def test_gesv_batched(self):
        TestTorch._test_gesv_batched(self, lambda t: t.cuda())

    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
    def test_gesv_batched_dims(self):
        TestTorch._test_gesv_batched_dims(self, lambda t: t.cuda())

    def test_view(self):
        TestTorch._test_view(self, lambda t: t.cuda())
