Skip to content

Commit d27c3ce

Browse files
apaszke authored and soumith committed
Fix cuBLAS arguments for fp16 dot (#3660)
* Fix cuBLAS arguments for fp16 dot
* Enable FloatTensor <-> CUDA HalfTensor checks in test_cuda.py
1 parent 280bf09 commit d27c3ce

File tree

4 files changed

+42
-26
lines changed

4 files changed

+42
-26
lines changed

test/test_cuda.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,18 @@ def get_cycles_per_ms():
384384
return _cycles_per_ms
385385

386386

387-
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
387+
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
388388
def tmp(self):
389389
cpu_tensor = tensor_constructor(t)
390-
gpu_tensor = to_gpu(cpu_tensor)
390+
type_map = {}
391+
if force_gpu_half:
392+
type_map = {
393+
'torch.FloatTensor': 'torch.cuda.HalfTensor',
394+
'torch.DoubleTensor': 'torch.cuda.HalfTensor',
395+
}
396+
gpu_tensor = to_gpu(cpu_tensor, type_map)
391397
cpu_args = arg_constructor(t)
392-
gpu_args = [to_gpu(arg) for arg in cpu_args]
398+
gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
393399
cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
394400
try:
395401
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1096,7 +1102,15 @@ def test_nvtx(self):
10961102
test_name += '_' + desc
10971103

10981104
assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
1099-
setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1105+
setattr(TestCuda,
1106+
test_name,
1107+
compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1108+
if t == torch.FloatTensor:
1109+
assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name + '_gpu_half'
1110+
setattr(TestCuda,
1111+
test_name + '_gpu_half',
1112+
compare_cpu_gpu(constr, arg_constr, name_inner, t,
1113+
precision, force_gpu_half=True))
11001114

11011115

11021116
if __name__ == '__main__':

torch/lib/THC/THCBlas.cu

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, doub
4949
}
5050

5151
#ifdef CUDA_HALF_TENSOR
52-
float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
52+
half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
5353
{
5454
#if CUDA_VERSION >= 8000
5555
if (n == 1) {
@@ -58,22 +58,23 @@ float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y
5858
}
5959

6060
if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
61-
int i_n = (int)n;
62-
int i_incx = (int)incx;
63-
int i_incy = (int)incy;
64-
float result;
61+
half result;
6562
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
6663
cublasSetStream(handle, THCState_getCurrentStream(state));
67-
THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_32F, CUDA_R_32F));
64+
THCublasCheck(cublasDotEx(handle, n,
65+
x, CUDA_R_16F, incx,
66+
y, CUDA_R_16F, incy,
67+
&result, CUDA_R_16F,
68+
CUDA_R_32F));
6869
return result;
69-
}
70+
}
7071

7172
THError("Cublas_Hdot only supports n, incx and incy "
7273
"up to signed integer limits: %d", INT_MAX);
73-
return 0;
74+
return THC_float2half(0);
7475
#else
7576
THError("Cublas_Hdot requires CUDA 8.0+");
76-
return 0;
77+
return THC_float2half(0);
7778
#endif
7879
}
7980
#endif
@@ -360,7 +361,7 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
360361
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount)
361362
{
362363
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
363-
364+
364365
{
365366
THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
366367
"with the bound [val] <= %d", INT_MAX);
@@ -420,7 +421,7 @@ void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, i
420421
cublasSetStream(handle, THCState_getCurrentStream(state));
421422
THCublasCheck(cublasDgemmStridedBatched(handle,
422423
opa, opb, (int)m, (int)n, (int)k,
423-
&alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
424+
&alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
424425
(int)batchCount));
425426
}
426427
#endif

torch/lib/THC/THCBlas.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
THC_API float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy);
99
THC_API double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, double *y, int64_t incy);
1010
#ifdef CUDA_HALF_TENSOR
11-
THC_API float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
11+
THC_API half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
1212
#endif
1313

1414
/* Level 2 */
@@ -36,7 +36,7 @@ THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char t
3636
float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
3737
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount);
3838
THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
39-
double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
39+
double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
4040
double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
4141
#endif
4242
/* Inverse */

torch/lib/THC/generic/THCTensorMathBlas.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
2424
THCTensor_(data)(state, self), 1,
2525
THCTensor_(data)(state, src), 1);
2626
#elif defined(THC_REAL_IS_HALF)
27-
accreal result = THCudaBlas_Hdot(state,
27+
accreal result = ScalarConvert<half, accreal>::to(
28+
THCudaBlas_Hdot(state,
2829
THCTensor_(nElement)(state, self),
2930
THCTensor_(data)(state, self), 1,
30-
THCTensor_(data)(state, src), 1);
31+
THCTensor_(data)(state, src), 1));
3132
#endif
3233

3334
THCTensor_(free)(state, src);
@@ -492,14 +493,14 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
492493
ldc = result_->stride[2];
493494
}
494495

495-
if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
496+
if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
496497
batch1->stride[transpose_result ? 1 : 2] != 0)
497498
{
498499
transpose_batch1 = 'n';
499500
batch1_ = batch1;
500501
lda = batch1_->stride[transpose_result ? 1 : 2];
501502
}
502-
else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
503+
else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
503504
batch1->stride[transpose_result ? 2 : 1] != 0)
504505
{
505506
transpose_batch1 = 't';
@@ -513,7 +514,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
513514
lda = batch1_->stride[1];
514515
}
515516

516-
if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
517+
if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
517518
batch2->stride[transpose_result ? 1 : 2] != 0)
518519
{
519520
transpose_batch2 = 'n';
@@ -537,7 +538,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
537538

538539
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
539540
// Compute pointers to matrices in each batch.
540-
#if CUDA_VERSION < 8000
541+
#if CUDA_VERSION < 8000
541542
size_t matrices_size = num_batches * sizeof(real*);
542543

543544
// Copy pointers to device.
@@ -592,7 +593,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
592593
THCudaFree(state, d_matrices1);
593594
THCudaFree(state, d_matrices2);
594595
THCudaFree(state, d_result_matrices);
595-
596+
596597
#else
597598
#ifdef THC_REAL_IS_FLOAT
598599
THCudaBlas_SgemmStridedBatched(
@@ -606,7 +607,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
606607
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
607608
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
608609
beta,
609-
THCTensor_(data)(state, result_), ldc, result_->stride[0],
610+
THCTensor_(data)(state, result_), ldc, result_->stride[0],
610611
num_batches);
611612
#elif defined(THC_REAL_IS_DOUBLE)
612613
THCudaBlas_DgemmStridedBatched(
@@ -620,7 +621,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
620621
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
621622
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
622623
beta,
623-
THCTensor_(data)(state, result_), ldc, result_->stride[0],
624+
THCTensor_(data)(state, result_), ldc, result_->stride[0],
624625
num_batches);
625626
#endif
626627
#endif

0 commit comments

Comments
 (0)