@@ -437,6 +437,16 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
437437 }
438438}
439439
// Fills three per-batch pointer arrays in a single kernel launch.
//
// For every batch index i in [0, num_batches):
//   buffer1[i] = data1 + i * stride1
//   buffer2[i] = data2 + i * stride2
//   buffer3[i] = data3 + i * stride3
//
// Strides are counted in elements of `real` (plain pointer arithmetic), not bytes.
// Each thread handles the single index given by its flat 1D global thread id, so the
// launch must supply at least num_batches threads along x; surplus threads do nothing.
// All six pointers are dereferenced/offset in device code and must be valid there.
__global__ void createBatchGemmBuffer3(const real** buffer1, const real** buffer2, const real** buffer3,
                                       real* data1, real* data2, real* data3,
                                       long stride1, long stride2, long stride3, long num_batches) {
  // Widen blockIdx.x before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and could wrap around before being stored
  // into the (wider, where long is 64-bit) index variable.
  const long idx = static_cast<long>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_batches) {
    buffer1[idx] = data1 + idx * stride1;
    buffer2[idx] = data2 + idx * stride2;
    buffer3[idx] = data3 + idx * stride3;
  }
}
449+
440450THC_API void
441451THCTensor_ (baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
442452 real alpha, THCTensor *batch1, THCTensor *batch2) {
@@ -554,15 +564,11 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
554564 const int64_t block = 512 ;
555565 const int64_t grid = (num_batches + block - 1 ) / block;
556566
557- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
558- d_matrices1, THCTensor_ (data)(state, batch1_), batch1_->stride [0 ],
559- num_batches);
560- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
561- d_matrices2, THCTensor_ (data)(state, batch2_), batch2_->stride [0 ],
562- num_batches);
563- createBatchGemmBuffer<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
564- (const real**)d_result_matrices, THCTensor_ (data)(state,result_),
565- result_->stride [0 ], num_batches);
567+ createBatchGemmBuffer3<<<grid, block, 0 , THCState_getCurrentStream(state)>>> (
568+ d_matrices1, d_matrices2, (const real**)d_result_matrices, THCTensor_ (data)(state, batch1_),
569+ THCTensor_ (data)(state, batch2_), THCTensor_ (data)(state, result_),
570+ batch1_->stride [0 ], batch2_->stride [0 ], result_->stride [0 ], num_batches);
571+
566572#ifdef THC_REAL_IS_FLOAT
567573 THCudaBlas_SgemmBatched (
568574 state,
@@ -591,7 +597,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
591597 beta,
592598 d_result_matrices, ldc,
593599 num_batches);
594- #endif
600+ #endif // THC_REAL
595601
596602 THCudaFree (state, d_matrices1);
597603 THCudaFree (state, d_matrices2);
@@ -626,10 +632,12 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
626632 beta,
627633 THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
628634 num_batches);
629- #endif
630- #endif
635+ #endif // THC_REAL
636+ #endif // CUDA_VERSION
631637
632638#elif defined(THC_REAL_IS_HALF)
639+
640+ #if CUDA_VERSION < 9100
633641 // Currently no HgemmBatched in Cublas
634642 for (int64_t i = 0 ; i < num_batches; ++i) {
635643 THCudaBlas_Hgemm (
@@ -645,8 +653,42 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
645653 beta,
646654 THCTensor_ (data)(state, result_) + i * result_->stride [0 ], ldc);
647655 }
648- #endif
656+ #else
657+ cudaDeviceProp* prop = THCState_getCurrentDeviceProperties (state);
658+ if (prop->major >= 5 ){
659+
660+ THCudaBlas_HgemmStridedBatched (
661+ state,
662+ transpose_batch1,
663+ transpose_batch2,
664+ result_->size [transpose_result ? 2 : 1 ],
665+ result_->size [transpose_result ? 1 : 2 ],
666+ batch1_->size [transpose_result ? 1 : 2 ],
667+ alpha,
668+ THCTensor_ (data)(state, batch1_), lda, batch1_->stride [0 ],
669+ THCTensor_ (data)(state, batch2_), ldb, batch2_->stride [0 ],
670+ beta,
671+ THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
672+ num_batches);
673+ } else {
674+ for (long i = 0 ; i < num_batches; ++i) {
675+ THCudaBlas_Hgemm (
676+ state,
677+ transpose_batch1,
678+ transpose_batch2,
679+ result_->size [transpose_result ? 2 : 1 ],
680+ result_->size [transpose_result ? 1 : 2 ],
681+ batch1_->size [transpose_result ? 1 : 2 ],
682+ alpha,
683+ THCTensor_ (data)(state, batch1_) + i * batch1_->stride [0 ], lda,
684+ THCTensor_ (data)(state, batch2_) + i * batch2_->stride [0 ], ldb,
685+ beta,
686+ THCTensor_ (data)(state, result_) + i * result_->stride [0 ], ldc);
687+ }
688+ }
649689
690+ #endif
691+ #endif
650692 if (batch1_ != batch1) {
651693 THCTensor_ (free )(state, batch1_);
652694 }
0 commit comments