27 changes: 7 additions & 20 deletions aten/src/ATen/Declarations.cwrap
@@ -2471,11 +2471,12 @@
- THTensor* mat2
]]
[[
name: bmm
name: _th_bmm
cname: baddbmm
variants:
- method
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
@@ -2525,10 +2526,12 @@
- THTensor* batch2
]]
[[
name: baddbmm
name: _th_baddbmm
cname: baddbmm
variants:
- method
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
@@ -2544,22 +2547,6 @@
- THTensor* batch1
- THTensor* batch2
]]
[[
name: baddbmm_
cname: baddbmm
return: argument 0
arguments:
- THTensor* self
- arg: real beta
default: AS_REAL(1)
kwarg_only: True
- THTensor* self
- arg: real alpha
default: AS_REAL(1)
kwarg_only: True
- THTensor* batch1
- THTensor* batch2
]]
[[
name: addcmul
variants:
150 changes: 150 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1,7 +1,10 @@
#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/TensorUtils.h"
#include "ATen/Parallel.h"
#include <functional>
#include <numeric>
#include <vector>
@@ -222,6 +225,153 @@ Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const T
return at::_addr_out(result, self, vec1, vec2, beta, alpha);
}

template <typename scalar_t, bool is_bmm>
inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
int64_t bs = result.size(0);
int64_t is = result.size(1);
int64_t js = result.size(2);
int64_t ks = self.size(2);

scalar_t alpha = alpha_.to<scalar_t>();
scalar_t beta = beta_.to<scalar_t>();

auto r0 = result.accessor<scalar_t, 3>();
auto s0 = self.accessor<scalar_t, 3>();
auto m0 = mat2.accessor<scalar_t, 3>();

int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
for (int64_t b = b_begin; b < b_end; b++) {
auto r1 = r0[b];
auto s1 = s0[b];
auto m1 = m0[b];
for (int64_t i = 0; i < is; i++) {
auto r2 = r1[i];
auto s2 = s1[i];
for (int64_t j = 0; j < js; j++) {
scalar_t &r = r2[j];
if (is_bmm) {
r = 0;
for (int64_t k = 0; k < ks; k++) {
r += s2[k] * m1[k][j];
}
} else {
r *= beta;
for (int64_t k = 0; k < ks; k++) {
r += alpha * s2[k] * m1[k][j];
}
}
}
}
}
});
}

// This applies a few optimizations to bmm/baddbmm:
// - When the per-batch matrices are small (contraction_size * res_rows * res_cols below the
//   threshold), computation is parallelized over the batch dimension using OMP and a naive
//   matrix multiplication is applied.
// - When the matrices are larger than the threshold and ATen is compiled with MKL, MKL's batch gemm is used.
// - Otherwise, we fall back to a series of single matrix multiplications.
// The threshold of 400 for the first case has not been thoroughly benchmarked yet and may have room for
// further tuning; it likely depends on the characteristics of the CPU (MKL vs. non-MKL builds, etc.),
// but it seems a reasonable starting point.

static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha, bool is_bmm_out) {
// is_bmm_out: true for bmm_out, false for baddbmm_
// self_or_result is "self" for baddbmm_ and "result" for bmm_out
CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm");
TensorArg self_arg(self_or_result, is_bmm_out ? "self" : "result", 0);
TensorArg b1_arg(batch1, "batch1", 1);
TensorArg b2_arg(batch2, "batch2", 2);
checkDim(c, b1_arg, 3);
checkDim(c, b2_arg, 3);

int64_t bs = batch1.size(0);
checkSize(c, b2_arg, 0, bs);
int64_t contraction_size = batch1.size(2);
int64_t res_rows = batch1.size(1);
int64_t res_cols = batch2.size(2);
checkSize(c, b2_arg, 1, contraction_size);

if (is_bmm_out) {
self_or_result.resize_({bs, res_rows, res_cols});
} else {
checkSize(c, self_arg, 0, bs);
checkSize(c, self_arg, 1, res_rows);
checkSize(c, self_arg, 2, res_cols);
}

// handle pathological cases that blas may not like
if (self_or_result.numel() == 0) {
return self_or_result;
} else if (contraction_size == 0) {
return self_or_result.zero_();
}

auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
return (t.stride(2) == 1 && t.stride(1) == t.size(2))
|| (t.stride(1) == 1 && t.stride(2) == t.size(1));
};

if (contraction_size * res_rows * res_cols < 400) {
if (is_bmm_out) {
AT_DISPATCH_ALL_TYPES(batch1.type(), "bmm", [&] {
baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
});
} else {
AT_DISPATCH_ALL_TYPES(batch1.type(), "baddbmm", [&] {
baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
});
}
} else if (at::hasMKL() && at::native::is_floating_point(self_or_result)
&& batch_items_contiguous_or_transposed(batch1)
&& batch_items_contiguous_or_transposed(batch2)
&& self_or_result.is_contiguous()) {
at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
} else { // split along batch dimension
if (is_bmm_out) {
for (int64_t b = 0; b < bs; b++) {
auto r = self_or_result.select(0, b);
at::native::mm_out(r, batch1.select(0, b), batch2.select(0, b));
}
} else {
for (int64_t b = 0; b < bs; b++) {
self_or_result.select(0, b).addmm_(batch1.select(0, b), batch2.select(0, b), beta, alpha);
}
}
}
return self_or_result;
}


Tensor baddbmm_cpu(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
Tensor result = self.type().tensor();
return at::native::baddbmm_out_cpu(result, self, batch1, batch2, beta, alpha);
}

Tensor& baddbmm_out_cpu(Tensor &result, const Tensor& self_, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
Tensor self;
std::tie(self) = expand_size(self_, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm");
result.resize_(self.sizes());
result.copy_(self);
return at::native::baddbmm__cpu(result, batch1, batch2, beta, alpha);
}

Tensor& baddbmm__cpu(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
return bmm_out_or_baddbmm_(self, batch1, batch2, beta, alpha, false);
}

Tensor bmm_cpu(const Tensor& self, const Tensor& mat2) {
Tensor result = self.type().tensor();
return at::native::bmm_out_cpu(result, self, mat2);
}

Tensor& bmm_out_cpu(Tensor &result, const Tensor& batch1, const Tensor& batch2) {
Scalar beta(0.0);
Scalar alpha(1.0);
return bmm_out_or_baddbmm_(result, batch1, batch2, beta, alpha, true);
}

Tensor dot(const Tensor& self, const Tensor& tensor) {
check_1d(self, "self", "dot");
check_1d(tensor, "tensor", "dot");
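The three CPU code paths above (the naive parallel kernel, the MKL batch gemm, and the per-batch mm_out/addmm_ fallback) must all agree on the same result. As a reference point, here is a minimal Python sketch of the semantics being implemented, expressed through the public torch.bmm/torch.baddbmm entry points; it is illustrative only and not the code path this PR adds.

```python
import torch

def baddbmm_reference(base, batch1, batch2, beta=1.0, alpha=1.0):
    # What every path must compute, batch item by batch item:
    #   out[b] = beta * base[b] + alpha * (batch1[b] @ batch2[b])
    bs, res_rows, _ = batch1.shape
    res_cols = batch2.shape[2]
    out = torch.empty(bs, res_rows, res_cols, dtype=base.dtype)
    for b in range(bs):  # the "split along batch dimension" fallback does exactly this
        out[b] = beta * base[b] + alpha * (batch1[b] @ batch2[b])
    return out

b1 = torch.randn(4, 3, 5)
b2 = torch.randn(4, 5, 2)
m0 = torch.randn(4, 3, 2)

# bmm is the beta = 0, alpha = 1 special case written into a freshly resized output.
assert torch.allclose(torch.bmm(b1, b2),
                      baddbmm_reference(torch.zeros(4, 3, 2), b1, b2, beta=0.0))
assert torch.allclose(torch.baddbmm(m0, b1, b2, beta=0.5, alpha=2.0),
                      baddbmm_reference(m0, b1, b2, beta=0.5, alpha=2.0))
```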
25 changes: 25 additions & 0 deletions aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -0,0 +1,25 @@
#include "ATen/ATen.h"

namespace at { namespace native {

Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
return _th_baddbmm(self, batch1, batch2, beta, alpha);
}

Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
return _th_baddbmm_out(result, self, batch1, batch2, beta, alpha);
}

Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
return _th_baddbmm_out(self, self, batch1, batch2, beta, alpha);
}

Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) {
return _th_bmm(self, mat2);
}

Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) {
return _th_bmm_out(result, batch1, batch2);
}

} }
95 changes: 95 additions & 0 deletions aten/src/ATen/native/mkl/LinearAlgebra.cpp
@@ -0,0 +1,95 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Config.h"

#if !AT_MKL_ENABLED()

namespace at { namespace native {

Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
AT_ERROR("bmm: ATen not compiled with MKL support");
}

}}

#else // AT_MKL_ENABLED

#include "ATen/ATen.h"
#include "ATen/Config.h"
#include "ATen/Dispatch.h"
#include "ATen/Utils.h"
#include "ATen/NativeFunctions.h"

#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>

#include <mkl.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

namespace at { namespace native {

static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
const int batch_size, const int M, const int N, const int K, const float alpha,
const float** A, const float** B, const float beta, float** C) {
const int lda = (trans_A == CblasNoTrans) ? K : M;
const int ldb = (trans_B == CblasNoTrans) ? N : K;
const int ldc = N;

cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
}

static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
const int batch_size, const int M, const int N, const int K, const double alpha,
const double** A, const double** B, const double beta, double** C) {
const int lda = (trans_A == CblasNoTrans) ? K : M;
const int ldb = (trans_B == CblasNoTrans) ? N : K;
const int ldc = N;

cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
}

template <typename scalar_t>
static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
auto is_transposed = [&](const Tensor& t) {
return t.stride(0) == 1 && t.stride(1) == t.size(0);
};
const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;

const int batch_size = mat1.size(0);
const int M = mat1.size(1);
const int N = mat2.size(2);
const int K = mat1.size(2);
scalar_t alpha = alpha_.to<scalar_t>();
scalar_t beta = beta_.to<scalar_t>();

std::vector<const scalar_t*> A(batch_size);
std::vector<const scalar_t*> B(batch_size);
std::vector<scalar_t*> C(batch_size);
for (int64_t batch = 0; batch < batch_size; batch++) {
A[batch] = mat1[batch].data<scalar_t>();
B[batch] = mat2[batch].data<scalar_t>();
C[batch] = res[batch].data<scalar_t>();
}

gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), B.data(), beta, C.data());
}

Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
// checks are done in native/LinearAlgebra.cpp
AT_DISPATCH_FLOATING_TYPES(self.type(), "baddbmm__mkl", [&] {
baddbmm_mkl_template<scalar_t>(self, batch1, batch2, beta, alpha);
});

return self;
}

}} // namespace at::native

#endif
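For reference, cblas_sgemm_batch/cblas_dgemm_batch with a single group of size batch_size applies the same GEMM to every batch item, C_b = alpha * op(A_b) * op(B_b) + beta * C_b, and in row-major layout the leading dimension of each operand is the column count of the matrix as it sits in memory, which is why lda/ldb switch between K/M and N/K above. The sketch below is a plain Python/NumPy model of that call, not MKL itself; the helper names are made up for illustration.

```python
import numpy as np

def gemm_batched_reference(trans_A, trans_B, A, B, C, alpha, beta):
    # Plain-Python model of cblas_?gemm_batch with group_count == 1:
    # every batch item gets C_b = alpha * op(A_b) @ op(B_b) + beta * C_b.
    # A, B, C are lists of row-major 2-D arrays, one per batch item.
    for a, b, c in zip(A, B, C):
        op_a = a.T if trans_A else a
        op_b = b.T if trans_B else b
        c[:] = alpha * (op_a @ op_b) + beta * c

def leading_dim(trans, logical_rows, logical_cols):
    # Row-major leading dimension = column count of the matrix as stored.
    # A logical M x K operand stored transposed occupies K x M memory, so its
    # leading dimension is M; stored untransposed it is K (analogously N/K for B).
    return logical_rows if trans else logical_cols

A = [np.random.rand(3, 4) for _ in range(5)]   # A_b: M x K with M=3, K=4
B = [np.random.rand(4, 2) for _ in range(5)]   # B_b: K x N with N=2
C = [np.zeros((3, 2)) for _ in range(5)]       # C_b: M x N
gemm_batched_reference(False, False, A, B, C, alpha=1.0, beta=0.0)
assert np.allclose(C[0], A[0] @ B[0])
```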
33 changes: 33 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -245,6 +245,27 @@
CPU: _atan_out_cpu
CUDA: _atan_out_cuda

- func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
dispatch:
CPU: baddbmm_cpu
CUDA: baddbmm_cuda

- func: baddbmm_(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: method
dispatch:
CPU: baddbmm__cpu
CUDA: baddbmm__cuda

- func: _baddbmm_mkl_(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function

- func: baddbmm_out(Tensor result, Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function
dispatch:
CPU: baddbmm_out_cpu
CUDA: baddbmm_out_cuda

- func: bartlett_window(int64_t window_length, TensorOptions options={}) -> Tensor

- func: bartlett_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor
@@ -281,6 +302,18 @@

- func: blackman_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor

- func: bmm(Tensor self, Tensor mat2) -> Tensor
variants: function, method
dispatch:
CPU: bmm_cpu
CUDA: bmm_cuda

- func: bmm_out(Tensor result, Tensor self, Tensor mat2) -> Tensor
variants: function
dispatch:
CPU: bmm_out_cpu
CUDA: bmm_out_cuda

- func: broadcast_tensors(TensorList tensors) -> TensorList

- func: cat(TensorList tensors, int64_t dim=0) -> Tensor
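At the Python level, the four schema entries added for baddbmm and the two for bmm surface as the usual function/method, in-place, and out= variants (_baddbmm_mkl_ remains an internal helper). A quick sketch, with arbitrary example shapes:

```python
import torch

a  = torch.randn(8, 4, 6)   # batch1: (b, n, m)
b  = torch.randn(8, 6, 5)   # batch2: (b, m, p)
m0 = torch.randn(8, 4, 5)   # self / accumulator: (b, n, p)

y_fn     = torch.baddbmm(m0, a, b, beta=0.5, alpha=2.0)   # baddbmm (function variant)
y_method = m0.baddbmm(a, b, beta=0.5, alpha=2.0)          # baddbmm (method variant)

y_inplace = m0.clone()
y_inplace.baddbmm_(a, b, beta=0.5, alpha=2.0)             # baddbmm_

y_out = torch.empty_like(m0)
torch.baddbmm(m0, a, b, beta=0.5, alpha=2.0, out=y_out)   # baddbmm_out

prod = torch.bmm(a, b)                                    # bmm
torch.bmm(a, b, out=torch.empty(8, 4, 5))                 # bmm_out

assert torch.allclose(y_fn, y_method)
assert torch.allclose(y_fn, y_inplace)
assert torch.allclose(y_fn, y_out)
```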
1 change: 1 addition & 0 deletions test/test_cuda.py
@@ -280,6 +280,7 @@ def tmp(t):
types, False, "skipIfRocm:HalfTensor"),
('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars',
types, False, "skipIfRocm:HalfTensor"),
('bmm', small_3d, lambda t: [small_3d(t)], '', float_types_no_half, False, "skipIfRocm:HalfTensor"),
('addcdiv', small_2d_lapack, lambda t: [tensor_mul(small_2d_lapack(t), 2), small_2d_lapack(t)], '',
types, False, "skipIfRocm:HalfTensor"),
('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t), tensor_mul(small_2d_lapack(t), 2), small_2d_lapack(t)],