Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions aten/src/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,36 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6
"with the bound [val] <= %d", INT_MAX);
}

#if CUDA_VERSION >= 9100
// Strided-batched half-precision GEMM.
// For each batch i in [0, batchCount):
//   C_i = alpha * op(A_i) * op(B_i) + beta * C_i
// where X_i = x + i * strideX. Inputs/outputs are fp16 (CUDA_R_16F) but the
// compute type is fp32 (CUDA_R_32F): the half scalars are widened to float
// before the call, and accumulation happens in fp32 for accuracy.
// Tensor-op (tensor core) math is enabled only for the duration of this call;
// the handle's math mode is restored to the default afterwards.
// Requires cublasGemmStridedBatchedEx, i.e. CUDA >= 9.1.
void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                             half beta, half *c, long ldc, long strideC, long batchCount)
{
  // cuBLAS takes 32-bit ints for these arguments; reject values that would
  // truncate when cast below.
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    THError("Cublas_HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  // With computeType CUDA_R_32F, cublasGemmStridedBatchedEx expects the
  // alpha/beta scalars as float, not half.
  float fAlpha = THC_half2float(alpha);
  float fBeta = THC_half2float(beta);
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  THCublasCheck(cublasGemmStridedBatchedEx(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                   b, CUDA_R_16F, (int)ldb, strideB,
                                   (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                                   (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  // Restore default math mode so later cuBLAS calls on this shared handle
  // are not silently switched to tensor-op math.
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif // CUDA_VERSION >= 9100

void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
Expand Down
7 changes: 7 additions & 0 deletions aten/src/THC/THCBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char t
double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
#endif

#if CUDA_VERSION >= 9100
/* Strided-batched half-precision GEMM (fp16 data, fp32 compute via
   cublasGemmStridedBatchedEx). Available only with CUDA >= 9.1.
   THC_API added for consistency with the sibling declarations above so the
   symbol is exported like the other BLAS wrappers. */
THC_API void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                             half beta, half *c, long ldc, long strideC, long batchCount);
#endif

/* Inverse */
THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);
Expand Down
68 changes: 55 additions & 13 deletions aten/src/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
}
}

// Fills three pointer arrays in one launch (instead of three separate
// createBatchGemmBuffer launches): bufferK[i] = dataK + i * strideK.
// Launch with a 1-D grid covering at least num_batches threads; one thread
// writes the pointer triple for one batch.
__global__ void createBatchGemmBuffer3(const real** buffer1, const real ** buffer2, const real ** buffer3, real* data1,
                                       real * data2, real * data3, long stride1, long stride2, long stride3, long num_batches) {
  const long batch = blockIdx.x * blockDim.x + threadIdx.x;
  // Guard the grid tail: the grid rarely divides num_batches evenly.
  if (batch >= num_batches) {
    return;
  }
  buffer1[batch] = data1 + batch * stride1;
  buffer2[batch] = data2 + batch * stride2;
  buffer3[batch] = data3 + batch * stride3;
}

THC_API void
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
real alpha, THCTensor *batch1, THCTensor *batch2) {
Expand Down Expand Up @@ -551,15 +561,11 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
const int64_t block = 512;
const int64_t grid = (num_batches + block - 1) / block;

createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
d_matrices1, THCTensor_(data)(state, batch1_), batch1_->stride[0],
num_batches);
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
d_matrices2, THCTensor_(data)(state, batch2_), batch2_->stride[0],
num_batches);
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
(const real**)d_result_matrices, THCTensor_(data)(state,result_),
result_->stride[0], num_batches);
createBatchGemmBuffer3<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
d_matrices1, d_matrices2, (const real**)d_result_matrices, THCTensor_(data)(state, batch1_),
THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_),
batch1_->stride[0], batch2_->stride[0], result_->stride[0], num_batches);

#ifdef THC_REAL_IS_FLOAT
THCudaBlas_SgemmBatched(
state,
Expand Down Expand Up @@ -588,7 +594,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
beta,
d_result_matrices, ldc,
num_batches);
#endif
#endif //THC_REAL

THCudaFree(state, d_matrices1);
THCudaFree(state, d_matrices2);
Expand Down Expand Up @@ -623,10 +629,12 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
beta,
THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
#endif
#endif
#endif //THC_REAL
#endif //CUDA_VERSION

#elif defined(THC_REAL_IS_HALF)

#if CUDA_VERSION < 9100
// Currently no HgemmBatched in Cublas
for (int64_t i = 0; i < num_batches; ++i) {
THCudaBlas_Hgemm(
Expand All @@ -642,8 +650,42 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
beta,
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
}
#endif
#else
cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
if (prop->major >= 5){

THCudaBlas_HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result_->size[transpose_result ? 2 : 1],
result_->size[transpose_result ? 1 : 2],
batch1_->size[transpose_result ? 1 : 2],
alpha,
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
beta,
THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
} else {
for (long i = 0; i < num_batches; ++i) {
THCudaBlas_Hgemm(
state,
transpose_batch1,
transpose_batch2,
result_->size[transpose_result ? 2 : 1],
result_->size[transpose_result ? 1 : 2],
batch1_->size[transpose_result ? 1 : 2],
alpha,
THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
beta,
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
}
}

#endif
#endif
if (batch1_ != batch1) {
THCTensor_(free)(state, batch1_);
}
Expand Down