Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions aten/src/ATen/native/mkl/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,36 @@ static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANS

template <typename scalar_t>
static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
auto is_transposed = [&](const Tensor& t) {
auto is_transposed = [&](const TensorAccessor<scalar_t, 2>& t) {
return t.stride(0) == 1 && t.stride(1) >= t.size(0);
};
const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;

const int batch_size = mat1.size(0);
const int M = mat1.size(1);
const int N = mat2.size(2);
const int K = mat1.size(2);
auto mat1_acc = mat1.accessor<scalar_t, 3>();
auto mat2_acc = mat2.accessor<scalar_t, 3>();
auto res_acc = res.accessor<scalar_t, 3>();

const CBLAS_TRANSPOSE trans_A = is_transposed(mat1_acc[0]) ? CblasTrans : CblasNoTrans;
const CBLAS_TRANSPOSE trans_B = is_transposed(mat2_acc[0]) ? CblasTrans : CblasNoTrans;

const int batch_size = mat1_acc.size(0);
const int M = mat1_acc.size(1);
const int N = mat2_acc.size(2);
const int K = mat1_acc.size(2);
scalar_t alpha = alpha_.to<scalar_t>();
scalar_t beta = beta_.to<scalar_t>();

const int lda = is_transposed(mat1[0]) ? mat1[0].stride(1) : mat1[0].stride(0);
const int ldb = is_transposed(mat2[0]) ? mat2[0].stride(1) : mat2[0].stride(0);
const int lda = is_transposed(mat1_acc[0]) ? mat1_acc[0].stride(1) : mat1_acc[0].stride(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this line over 80 characters?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to 119 characters is okay.

const int ldb = is_transposed(mat2_acc[0]) ? mat2_acc[0].stride(1) : mat2_acc[0].stride(0);
const int ldc = res[0].stride(0);

std::vector<const scalar_t*> A(batch_size);
std::vector<const scalar_t*> B(batch_size);
std::vector<scalar_t*> C(batch_size);

for (int64_t batch = 0; batch < batch_size; batch++) {
A[batch] = mat1[batch].data<scalar_t>();
B[batch] = mat2[batch].data<scalar_t>();
C[batch] = res[batch].data<scalar_t>();
A[batch] = mat1_acc[batch].data();
B[batch] = mat2_acc[batch].data();
C[batch] = res_acc[batch].data();
}

gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
Expand Down