Skip to content

Commit 9e18db9

Browse files
csarofeen authored and soumith committed
Add cublas batched gemm support. (#4151)
* Add cublas batched gemm.
* Comment cleanup batched gemm.
* Fix cuda versioning batched gemm.
1 parent 84da898 commit 9e18db9

File tree

3 files changed

+92
-13
lines changed

3 files changed

+92
-13
lines changed

torch/lib/THC/THCBlas.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,36 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6
332332
"with the bound [val] <= %d", INT_MAX);
333333
}
334334

335+
#if CUDA_VERSION >= 9100
// Strided-batched half-precision GEMM.
// Performs batchCount independent GEMMs: C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i],
// where batch i starts at a + i*strideA (likewise for b/c). Dispatches to
// cublasGemmStridedBatchedEx with FP16 inputs/outputs and FP32 compute, temporarily
// enabling tensor-op math. Requires CUDA >= 9.1 (guarded above).
// NOTE(review): alpha/beta are converted to float because the Ex entry point with
// CUDA_R_32F computeType takes float scaling factors.
void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                    half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                                    half beta, half *c, long ldc, long strideC, long batchCount)
{
  // cuBLAS takes these dimensions as int; reject anything that would overflow.
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    // Fixed: message previously said "Sgemm" and lacked the space before "with".
    THError("Cublas_HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
  }

  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  float fAlpha = THC_half2float(alpha);
  float fBeta = THC_half2float(beta);
  // Allow tensor cores for this call only; restore the default mode afterwards so
  // other callers sharing the handle are unaffected.
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  THCublasCheck(cublasGemmStridedBatchedEx(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                   b, CUDA_R_16F, (int)ldb, strideB,
                                   (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                                   (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
335365

336366
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
337367
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,

torch/lib/THC/THCBlas.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char t
3939
double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
4040
double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
4141
#endif
42+
43+
#if CUDA_VERSION >= 9100
/* Strided-batched half-precision GEMM (FP16 data, FP32 compute); requires CUDA >= 9.1.
   Added THC_API for consistency with the Sgemm/Dgemm batched declarations above, so
   the symbol is exported the same way as its siblings. */
THC_API void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                             half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                             half beta, half *c, long ldc, long strideC, long batchCount);
#endif
48+
4249
/* Inverse */
4350
THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
4451
THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);

torch/lib/THC/generic/THCTensorMathBlas.cu

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
434434
}
435435
}
436436

437+
// Populate three device pointer arrays in one launch: for each batch index b,
// buffer_i[b] is set to data_i + b * stride_i. One thread handles one batch
// entry; launch with grid*block >= num_batches (extra threads exit early).
__global__ void createBatchGemmBuffer3(const real** buffer1, const real ** buffer2, const real ** buffer3, real* data1,
                                       real * data2, real * data3, long stride1, long stride2, long stride3, long num_batches) {
  const long batch = blockIdx.x * blockDim.x + threadIdx.x;
  if (batch >= num_batches) {
    return;  // tail threads past the last batch do nothing
  }
  buffer1[batch] = data1 + batch * stride1;
  buffer2[batch] = data2 + batch * stride2;
  buffer3[batch] = data3 + batch * stride3;
}
446+
437447
THC_API void
438448
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
439449
real alpha, THCTensor *batch1, THCTensor *batch2) {
@@ -551,15 +561,11 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
551561
const int64_t block = 512;
552562
const int64_t grid = (num_batches + block - 1) / block;
553563

554-
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
555-
d_matrices1, THCTensor_(data)(state, batch1_), batch1_->stride[0],
556-
num_batches);
557-
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
558-
d_matrices2, THCTensor_(data)(state, batch2_), batch2_->stride[0],
559-
num_batches);
560-
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
561-
(const real**)d_result_matrices, THCTensor_(data)(state,result_),
562-
result_->stride[0], num_batches);
564+
createBatchGemmBuffer3<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
565+
d_matrices1, d_matrices2, (const real**)d_result_matrices, THCTensor_(data)(state, batch1_),
566+
THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_),
567+
batch1_->stride[0], batch2_->stride[0], result_->stride[0], num_batches);
568+
563569
#ifdef THC_REAL_IS_FLOAT
564570
THCudaBlas_SgemmBatched(
565571
state,
@@ -588,7 +594,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
588594
beta,
589595
d_result_matrices, ldc,
590596
num_batches);
591-
#endif
597+
#endif //THC_REAL
592598

593599
THCudaFree(state, d_matrices1);
594600
THCudaFree(state, d_matrices2);
@@ -623,10 +629,12 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
623629
beta,
624630
THCTensor_(data)(state, result_), ldc, result_->stride[0],
625631
num_batches);
626-
#endif
627-
#endif
632+
#endif //THC_REAL
633+
#endif //CUDA_VERSION
628634

629635
#elif defined(THC_REAL_IS_HALF)
636+
637+
#if CUDA_VERSION < 9100
630638
// Currently no HgemmBatched in Cublas
631639
for (int64_t i = 0; i < num_batches; ++i) {
632640
THCudaBlas_Hgemm(
@@ -642,8 +650,42 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
642650
beta,
643651
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
644652
}
645-
#endif
653+
#else
654+
cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
655+
if (prop->major >= 5){
656+
657+
THCudaBlas_HgemmStridedBatched(
658+
state,
659+
transpose_batch1,
660+
transpose_batch2,
661+
result_->size[transpose_result ? 2 : 1],
662+
result_->size[transpose_result ? 1 : 2],
663+
batch1_->size[transpose_result ? 1 : 2],
664+
alpha,
665+
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
666+
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
667+
beta,
668+
THCTensor_(data)(state, result_), ldc, result_->stride[0],
669+
num_batches);
670+
} else {
671+
for (long i = 0; i < num_batches; ++i) {
672+
THCudaBlas_Hgemm(
673+
state,
674+
transpose_batch1,
675+
transpose_batch2,
676+
result_->size[transpose_result ? 2 : 1],
677+
result_->size[transpose_result ? 1 : 2],
678+
batch1_->size[transpose_result ? 1 : 2],
679+
alpha,
680+
THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
681+
THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
682+
beta,
683+
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
684+
}
685+
}
646686

687+
#endif
688+
#endif
647689
if (batch1_ != batch1) {
648690
THCTensor_(free)(state, batch1_);
649691
}

0 commit comments

Comments
 (0)