
Commit 5a264b4

csarofeen authored and soumith committed

Add cublas batched gemm support. (#4151)

* Add cublas batched gemm.
* Comment cleanup batched gemm.
* Fix cuda versioning batched gemm.

1 parent fac711c commit 5a264b4

File tree

3 files changed
+92 -13 lines changed

aten/src/THC/THCBlas.cu

Lines changed: 30 additions & 0 deletions

@@ -332,6 +332,36 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6
             "with the bound [val] <= %d", INT_MAX);
   }

+#if CUDA_VERSION >= 9100
+void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
+                                    half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+                                    half beta, half *c, long ldc, long strideC, long batchCount)
+{
+  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
+  {
+    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
+            "with the bound [val] <= %d", INT_MAX);
+  }
+
+  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
+  cublasOperation_t opa = convertTransToCublasOperation(transa);
+  cublasOperation_t opb = convertTransToCublasOperation(transb);
+
+  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
+  cublasSetStream(handle, THCState_getCurrentStream(state));
+  float fAlpha = THC_half2float(alpha);
+  float fBeta = THC_half2float(beta);
+  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  THCublasCheck(cublasGemmStridedBatchedEx(handle,
+                                           opa, opb, (int)m, (int)n, (int)k,
+                                           (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
+                                           b, CUDA_R_16F, (int)ldb, strideB,
+                                           (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
+                                           (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+}
+#endif

 void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
                              float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
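The wrapper added above is a thin layer over cuBLAS: it converts the half alpha/beta scalars to float and dispatches to cublasGemmStridedBatchedEx with FP16 operands and FP32 accumulation. For reference, a minimal standalone sketch of that cuBLAS call (illustrative only, not part of this commit; the handle, device buffers, and sizes are assumed, and error checking is omitted):

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Illustrative only: multiply `batch` pairs of column-major m x k (A) and k x n (B)
// half-precision matrices laid out back to back, accumulating in float -- the same
// configuration the wrapper above sets up. d_A, d_B, d_C are device pointers.
void hgemm_strided_batched_sketch(cublasHandle_t handle,
                                  const __half *d_A, const __half *d_B, __half *d_C,
                                  int m, int n, int k, int batch)
{
  float fAlpha = 1.0f, fBeta = 0.0f;     // scalars are float because computeType is CUDA_R_32F
  long long strideA = (long long)m * k;  // elements between consecutive A matrices
  long long strideB = (long long)k * n;
  long long strideC = (long long)m * n;
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);   // allow Tensor Core paths, as the wrapper does
  cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                             &fAlpha, d_A, CUDA_R_16F, m, strideA,
                                      d_B, CUDA_R_16F, k, strideB,
                             &fBeta,  d_C, CUDA_R_16F, m, strideC,
                             batch, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);     // restore the default math mode
}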

aten/src/THC/THCBlas.h

Lines changed: 7 additions & 0 deletions

@@ -39,6 +39,13 @@ THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char t
                                             double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
                                             double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
 #endif
+
+#if CUDA_VERSION >= 9100
+void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
+                                    half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+                                    half beta, half *c, long ldc, long strideC, long batchCount);
+#endif
+
 /* Inverse */
 THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
 THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);
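For illustration, a hypothetical call site for this new declaration (not part of the commit; the pointers, scalars, sizes, and contiguous column-major layout are all assumed) would guard on the same CUDA_VERSION check and fall back to a per-matrix Hgemm loop on older toolkits, mirroring the baddbmm change in the next file:

// Hypothetical caller: batchCount column-major m x k and k x n half matrices
// stored contiguously, no transposition. a, b, c are device pointers.
#if CUDA_VERSION >= 9100
  THCudaBlas_HgemmStridedBatched(state, 'n', 'n', m, n, k,
                                 alpha,
                                 a, m, (long)m * k,   // lda, strideA
                                 b, k, (long)k * n,   // ldb, strideB
                                 beta,
                                 c, m, (long)m * n,   // ldc, strideC
                                 batchCount);
#else
  // One GEMM per batch element, stepping through the contiguous buffers.
  for (long i = 0; i < batchCount; ++i) {
    THCudaBlas_Hgemm(state, 'n', 'n', m, n, k, alpha,
                     a + i * (long)m * k, m,
                     b + i * (long)k * n, k,
                     beta,
                     c + i * (long)m * n, m);
  }
#endif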

aten/src/THC/generic/THCTensorMathBlas.cu

Lines changed: 55 additions & 13 deletions

@@ -437,6 +437,16 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
   }
 }

+__global__ void createBatchGemmBuffer3(const real** buffer1, const real ** buffer2, const real ** buffer3, real* data1,
+                                       real * data2, real * data3, long stride1, long stride2, long stride3, long num_batches) {
+  const long idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < num_batches) {
+    buffer1[idx] = data1 + idx * stride1;
+    buffer2[idx] = data2 + idx * stride2;
+    buffer3[idx] = data3 + idx * stride3;
+  }
+}
+
 THC_API void
 THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
                     real alpha, THCTensor *batch1, THCTensor *batch2) {
@@ -554,15 +564,11 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
   const int64_t block = 512;
   const int64_t grid = (num_batches + block - 1) / block;

-  createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-    d_matrices1, THCTensor_(data)(state, batch1_), batch1_->stride[0],
-    num_batches);
-  createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-    d_matrices2, THCTensor_(data)(state, batch2_), batch2_->stride[0],
-    num_batches);
-  createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-    (const real**)d_result_matrices, THCTensor_(data)(state,result_),
-    result_->stride[0], num_batches);
+  createBatchGemmBuffer3<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+    d_matrices1, d_matrices2, (const real**)d_result_matrices, THCTensor_(data)(state, batch1_),
+    THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_),
+    batch1_->stride[0], batch2_->stride[0], result_->stride[0], num_batches);
+
 #ifdef THC_REAL_IS_FLOAT
   THCudaBlas_SgemmBatched(
     state,
@@ -591,7 +597,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     beta,
     d_result_matrices, ldc,
     num_batches);
-#endif
+#endif //THC_REAL

   THCudaFree(state, d_matrices1);
   THCudaFree(state, d_matrices2);
@@ -626,10 +632,12 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     beta,
     THCTensor_(data)(state, result_), ldc, result_->stride[0],
     num_batches);
-#endif
-#endif
+#endif //THC_REAL
+#endif //CUDA_VERSION

 #elif defined(THC_REAL_IS_HALF)
+
+#if CUDA_VERSION < 9100
   // Currently no HgemmBatched in Cublas
   for (int64_t i = 0; i < num_batches; ++i) {
     THCudaBlas_Hgemm(
@@ -645,8 +653,42 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
       beta,
       THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
   }
-#endif
+#else
+  cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
+  if (prop->major >= 5){
+
+    THCudaBlas_HgemmStridedBatched(
+      state,
+      transpose_batch1,
+      transpose_batch2,
+      result_->size[transpose_result ? 2 : 1],
+      result_->size[transpose_result ? 1 : 2],
+      batch1_->size[transpose_result ? 1 : 2],
+      alpha,
+      THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
+      THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
+      beta,
+      THCTensor_(data)(state, result_), ldc, result_->stride[0],
+      num_batches);
+  } else {
+    for (long i = 0; i < num_batches; ++i) {
+      THCudaBlas_Hgemm(
+        state,
+        transpose_batch1,
+        transpose_batch2,
+        result_->size[transpose_result ? 2 : 1],
+        result_->size[transpose_result ? 1 : 2],
+        batch1_->size[transpose_result ? 1 : 2],
+        alpha,
+        THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
+        THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
+        beta,
+        THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
+    }
+  }

+#endif
+#endif
   if (batch1_ != batch1) {
     THCTensor_(free)(state, batch1_);
   }
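For context on the pointer-array branch above (the float/double SgemmBatched path that createBatchGemmBuffer3 now feeds with one kernel launch instead of three): the device arrays of per-matrix pointers are handed straight to cuBLAS's batched GEMM. A rough standalone float sketch of that consumer (names and setup are illustrative, not from the commit; allocation and error checking omitted):

#include <cublas_v2.h>

// d_Aptrs, d_Bptrs, d_Cptrs are device arrays of `batch` pointers, each entry
// pointing at one column-major matrix inside a contiguous buffer -- exactly
// what createBatchGemmBuffer3 fills in a single launch.
void sgemm_batched_sketch(cublasHandle_t handle,
                          const float *const *d_Aptrs, const float *const *d_Bptrs,
                          float *const *d_Cptrs, int m, int n, int k, int batch)
{
  float alpha = 1.0f, beta = 0.0f;
  cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                     &alpha, d_Aptrs, m, d_Bptrs, k,
                     &beta, d_Cptrs, m, batch);
}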
