+#include "ATen/Context.h"
+#include "ATen/Dispatch.h"
+#include "ATen/NativeFunctions.h"
+#include "ATen/PinnedMemoryAllocator.h"
+#include "ATen/cuda/CUDAApplyUtils.cuh"
+
+#include "ATen/native/LinearAlgebraUtils.h"
+#include "ATen/native/Gesv.h"
+
+#include "THC.h" // for USE_MAGMA
+
+#ifdef USE_MAGMA
+#include <magma.h>
+#include <magma_types.h>
+#endif
+
+namespace at {
+namespace native {
+
+#ifdef USE_MAGMA
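+// Thin type-dispatched wrappers around MAGMA's batched GESV. The generic
+// template rejects unsupported scalar types; the float and double
+// specializations forward to magma_sgesv_batched / magma_dgesv_batched.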
+template<class scalar_t>
+void magmaGesvBatched(
+    magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
+    magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
+    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
+  AT_ERROR("gesv only takes float or double Tensors");
+}
+
+template<>
+void magmaGesvBatched<float>(
+    magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
+    magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
+    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
+  magma_sgesv_batched(
+      n, nrhs, dA_array, ldda, dipiv_array,
+      dB_array, lddb, dinfo_array, batch_count, queue);
+}
+
+template<>
+void magmaGesvBatched<double>(
+    magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
+    magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
+    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
+  magma_dgesv_batched(
+      n, nrhs, dA_array, ldda, dipiv_array,
+      dB_array, lddb, dinfo_array, batch_count, queue);
+}
+
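+// Builds a MAGMA queue bound to the tensor's device, reusing the current
+// CUDA stream and the cuBLAS/cuSPARSE handles owned by the THC state.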
+static magma_queue_t createMagmaQueue(const Tensor& tensor) {
+  auto& context = tensor.type().get_context();
+  magma_queue_t magma_queue;
+  magma_queue_create_from_cuda(
+      tensor.get_device(),
+      context.getCurrentCUDAStream(),
+      THCState_getCurrentBlasHandle(context.thc_state),
+      THCState_getCurrentSparseHandle(context.thc_state),
+      &magma_queue);
+  return magma_queue;
+}
+#endif
+
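+// Narrows an int64_t to magma_int_t, raising an error if the value does not
+// round-trip (i.e. it would overflow MAGMA's integer type).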
+static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
+  auto result = static_cast<magma_int_t>(value);
+  if (static_cast<int64_t>(result) != value) {
+    AT_ERROR("magma: The value of %s (%lld) is too large to fit into a magma_int_t (%llu bytes)",
+        varname, (long long)value, sizeof(magma_int_t));
+  }
+  return result;
+}
+
+// Creates an array of size elements of type T, backed by pinned memory
+// wrapped in a Storage
+template<class T>
+static inline std::unique_ptr<Storage> pin_memory(int64_t size, Tensor dummy) {
+  int64_t adjusted_size = size * sizeof(T);
+  auto allocator = std::unique_ptr<Allocator>(new PinnedMemoryAllocator());
+  auto& backend = dummy.type().toBackend(kCPU).toScalarType(kByte);
+  return backend.storageWithAllocator(adjusted_size, std::move(allocator));
+}
+
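+// Declares a pinned-memory buffer of `size` elements of `type`, keeps it
+// alive through a local storage_<name> handle, and points `name` at its data.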
+#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
+  auto storage_##name = pin_memory<type>(size, dummy_tensor); \
+  name = reinterpret_cast<type*>(storage_##name->data());
+
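+// Solves the batch of linear systems A_i X_i = b_i in place: on return, b
+// holds the solutions and A holds the LU factors. Per-matrix MAGMA status
+// codes are written into `infos` for the caller to inspect.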
+template <typename scalar_t>
+static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
+#ifndef USE_MAGMA
+AT_ERROR("gesv: MAGMA library not found in "
+    "compilation. Please rebuild with MAGMA.");
+#else
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto A_mat_stride = matrixStride(A);
+  auto b_mat_stride = matrixStride(b);
+
+  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
+  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
+  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+
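+  // Pinned host buffers: per-matrix info codes, pivot storage, and the
+  // pointer arrays that MAGMA's batched interface expects.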
+  magma_int_t* info_array;
+  magma_int_t* ipiv_data;
+  magma_int_t** ipiv_array;
+  scalar_t** A_array;
+  scalar_t** b_array;
+
+  ALLOCATE_ARRAY(info_array, magma_int_t, batch_size, b);
+  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
+  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
+  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
+  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
+
+  // Set up the created arrays
+  for (int64_t i = 0; i < batch_size; i++) {
+    A_array[i] = &A_data[i * A_mat_stride];
+    b_array[i] = &b_data[i * b_mat_stride];
+    ipiv_array[i] = &ipiv_data[i * n];
+  }
+
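+  // One batched MAGMA call solves every system; ldda = lddb = n because the
+  // column-major working copies have n rows.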
+  magmaGesvBatched<scalar_t>(
+      n, nrhs, A_array, n, ipiv_array, b_array, n,
+      info_array, batch_size, createMagmaQueue(b));
+
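+  // Copy MAGMA's per-matrix status codes back so the caller can detect
+  // singular inputs.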
+  for (int64_t i = 0; i < batch_size; i++) {
+    infos[i] = info_array[i];
+  }
+#endif
+}
+
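+// CUDA entry point for gesv: clones A and b into the column-major layout
+// MAGMA expects, dispatches on the floating-point type, and surfaces any
+// per-matrix failures via checkErrors.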
+std::tuple<Tensor,Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
+  std::vector<int64_t> infos(batchCount(A), 0);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  auto b_working_copy = cloneBatchedColumnMajor(self);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
+  });
+  checkErrors(infos);
+  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
+}
+
+}} // namespace at::native
+
+#undef ALLOCATE_ARRAY