@@ -24,10 +24,11 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
2424 THCTensor_ (data)(state, self), 1 ,
2525 THCTensor_ (data)(state, src), 1 );
2626#elif defined(THC_REAL_IS_HALF)
27- accreal result = THCudaBlas_Hdot (state,
27+ accreal result = ScalarConvert<half, accreal>::to (
28+ THCudaBlas_Hdot (state,
2829 THCTensor_ (nElement)(state, self),
2930 THCTensor_ (data)(state, self), 1 ,
30- THCTensor_ (data)(state, src), 1 );
31+ THCTensor_ (data)(state, src), 1 )) ;
3132#endif
3233
3334 THCTensor_ (free )(state, src);
@@ -492,14 +493,14 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
492493 ldc = result_->stride [2 ];
493494 }
494495
495- if (batch1->stride [transpose_result ? 2 : 1 ] == 1 &&
496+ if (batch1->stride [transpose_result ? 2 : 1 ] == 1 &&
496497 batch1->stride [transpose_result ? 1 : 2 ] != 0 )
497498 {
498499 transpose_batch1 = ' n' ;
499500 batch1_ = batch1;
500501 lda = batch1_->stride [transpose_result ? 1 : 2 ];
501502 }
502- else if (batch1->stride [transpose_result ? 1 : 2 ] == 1 &&
503+ else if (batch1->stride [transpose_result ? 1 : 2 ] == 1 &&
503504 batch1->stride [transpose_result ? 2 : 1 ] != 0 )
504505 {
505506 transpose_batch1 = ' t' ;
@@ -513,7 +514,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
513514 lda = batch1_->stride [1 ];
514515 }
515516
516- if (batch2->stride [transpose_result ? 2 : 1 ] == 1 &&
517+ if (batch2->stride [transpose_result ? 2 : 1 ] == 1 &&
517518 batch2->stride [transpose_result ? 1 : 2 ] != 0 )
518519 {
519520 transpose_batch2 = ' n' ;
@@ -537,7 +538,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
537538
538539#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
539540 // Compute pointers to matrices in each batch.
540- #if CUDA_VERSION < 8000
541+ #if CUDA_VERSION < 8000
541542 size_t matrices_size = num_batches * sizeof (real*);
542543
543544// Copy pointers to device.
@@ -592,7 +593,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
592593 THCudaFree (state, d_matrices1);
593594 THCudaFree (state, d_matrices2);
594595 THCudaFree (state, d_result_matrices);
595-
596+
596597#else
597598#ifdef THC_REAL_IS_FLOAT
598599 THCudaBlas_SgemmStridedBatched (
@@ -606,7 +607,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
606607 THCTensor_ (data)(state, batch1_), lda, batch1_->stride [0 ],
607608 THCTensor_ (data)(state, batch2_), ldb, batch2_->stride [0 ],
608609 beta,
609- THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
610+ THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
610611 num_batches);
611612#elif defined(THC_REAL_IS_DOUBLE)
612613 THCudaBlas_DgemmStridedBatched (
@@ -620,7 +621,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
620621 THCTensor_ (data)(state, batch1_), lda, batch1_->stride [0 ],
621622 THCTensor_ (data)(state, batch2_), ldb, batch2_->stride [0 ],
622623 beta,
623- THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
624+ THCTensor_ (data)(state, result_), ldc, result_->stride [0 ],
624625 num_batches);
625626#endif
626627#endif
0 commit comments