@@ -434,6 +434,16 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
434434 }
435435}
436436
// Fills three device-side pointer tables for a batched GEMM in a single launch.
//
// For every batch index i in [0, num_batches):
//   buffer1[i] = data1 + i * stride1
//   buffer2[i] = data2 + i * stride2
//   buffer3[i] = data3 + i * stride3
//
// Parameters:
//   buffer1..3   output tables, each with room for num_batches pointers
//   data1..3     base pointers of the three operands
//   stride1..3   distance between consecutive batch entries, in elements of `real`
//   num_batches  number of entries to write into each table
//
// Launch layout: 1-D grid of 1-D blocks; one thread handles one batch index,
// and threads whose index is >= num_batches do nothing. No shared memory is used.
//
// Strides and the batch count are int64_t rather than `long`: `long` is only
// 32 bits wide on LLP64 platforms (e.g. Windows), which would truncate 64-bit
// strides passed by the caller and let idx * stride overflow.
__global__ void createBatchGemmBuffer3(const real** buffer1, const real** buffer2, const real** buffer3,
                                       real* data1, real* data2, real* data3,
                                       int64_t stride1, int64_t stride2, int64_t stride3,
                                       int64_t num_batches) {
  // Widen before multiplying; blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic.
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_batches) {
    buffer1[idx] = data1 + idx * stride1;
    buffer2[idx] = data2 + idx * stride2;
    buffer3[idx] = data3 + idx * stride3;
  }
}
446+
437447THC_API void
438448THCTensor_ (baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
439449 real alpha, THCTensor *batch1, THCTensor *batch2) {
@@ -551,15 +561,11 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
551561 const int64_t block = 512 ;
552562 const int64_t grid = (num_batches + block - 1 ) / block;
553563
554- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
555- d_matrices1, THCTensor_ (data)(state, batch1_), batch1_->stride [0 ],
556- num_batches);
557- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
558- d_matrices2, THCTensor_ (data)(state, batch2_), batch2_->stride [0 ],
559- num_batches);
560- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
561- (const real**)d_result_matrices, THCTensor_ (data)(state,result_),
562- result_->stride [0 ], num_batches);
564+ createBatchGemmBuffer3<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
565+ d_matrices1, d_matrices2, (const real**)d_result_matrices, THCTensor_ (data)(state, batch1_),
566+ THCTensor_ (data)(state, batch2_), THCTensor_ (data)(state, result_),
567+ batch1_->stride [0 ], batch2_->stride [0 ], result_->stride [0 ], num_batches);
568+
563569#ifdef THC_REAL_IS_FLOAT
564570 THCudaBlas_SgemmBatched (
565571 state,
@@ -588,7 +594,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
588594 beta,
589595 d_result_matrices, ldc,
590596 num_batches);
591- #endif
597+ #endif // THC_REAL
592598
593599 THCudaFree (state, d_matrices1);
594600 THCudaFree (state, d_matrices2);
@@ -623,10 +629,12 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
623629 beta,
624630 THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
625631 num_batches);
626- #endif
627- #endif
632+ #endif // THC_REAL
633+ #endif // CUDA_VERSION
628634
629635#elif defined(THC_REAL_IS_HALF)
636+
637+ #if CUDA_VERSION < 9100
630638 // Currently no HgemmBatched in Cublas
631639 for (int64_t i = 0 ; i < num_batches; ++i) {
632640 THCudaBlas_Hgemm (
@@ -642,8 +650,42 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
642650 beta,
643651 THCTensor_ (data)(state, result_) + i * result_->stride [0 ], ldc);
644652 }
645- #endif
653+ #else
654+ cudaDeviceProp* prop = THCState_getCurrentDeviceProperties (state);
655+ if (prop->major >= 5 ){
656+
657+ THCudaBlas_HgemmStridedBatched (
658+ state,
659+ transpose_batch1,
660+ transpose_batch2,
661+ result_->size [transpose_result ? 2 : 1 ],
662+ result_->size [transpose_result ? 1 : 2 ],
663+ batch1_->size [transpose_result ? 1 : 2 ],
664+ alpha,
665+ THCTensor_ (data)(state, batch1_), lda, batch1_->stride [0 ],
666+ THCTensor_ (data)(state, batch2_), ldb, batch2_->stride [0 ],
667+ beta,
668+ THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
669+ num_batches);
670+ } else {
671+ for (long i = 0 ; i < num_batches; ++i) {
672+ THCudaBlas_Hgemm (
673+ state,
674+ transpose_batch1,
675+ transpose_batch2,
676+ result_->size [transpose_result ? 2 : 1 ],
677+ result_->size [transpose_result ? 1 : 2 ],
678+ batch1_->size [transpose_result ? 1 : 2 ],
679+ alpha,
680+ THCTensor_ (data)(state, batch1_) + i * batch1_->stride [0 ], lda,
681+ THCTensor_ (data)(state, batch2_) + i * batch2_->stride [0 ], ldb,
682+ beta,
683+ THCTensor_ (data)(state, result_) + i * result_->stride [0 ], ldc);
684+ }
685+ }
646686
687+ #endif
688+ #endif
647689 if (batch1_ != batch1) {
648690 THCTensor_ (free )(state, batch1_);
649691 }
0 commit comments